#2869 - More YAML/test fixes to address failures

This commit is contained in:
Charlie Crane
2024-12-17 12:47:54 +00:00
parent 3b1b74fb3a
commit 770896200b
7 changed files with 16 additions and 12 deletions

View File

@@ -260,11 +260,6 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
action_masking: bool = agent_settings.action_masking if agent_settings else False
# @property
# def most_recent_action(self) -> ActType:
# """Convenience method to access the agents most recent action."""
# return self._most_recent_action
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Return the agent's most recent action, formatted in CAOS format.

View File

@@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
)
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
if agent.action_masking:
if agent.config.action_masking:
self.observation_space[agent_name] = spaces.Dict(
{
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
@@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
if agent.action_masking:
if agent.config.action_masking:
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
else:
all_obs[agent_name] = obs
@@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env):
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
if self.env.agent.action_masking:
if self.env.agent.config.agent_settings.action_masking:
self.observation_space = spaces.Dict(
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
)
@@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
if self.env.agent.config.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
return new_obs, *_
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
# if action masking is enabled, intercept the step method and add action mask to observation
if self.env.agent.action_masking:
if self.env.agent.config.action_masking:
obs, *_ = self.env.step(action)
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
return new_obs, *_

View File

@@ -60,6 +60,9 @@ agents:
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.4
1: 0.6
simulation:
network:

View File

@@ -85,6 +85,9 @@ agents:
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.4
1: 0.6
simulation:

View File

@@ -92,6 +92,9 @@ agents:
reward_function:
reward_components:
- type: DUMMY
agent_settings:
flatten_obs: True
action_masking: False
simulation:
network:

View File

@@ -49,7 +49,7 @@ def test_application_install_uninstall_on_uc2():
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(env_config=cfg)
env.agent.flatten_obs = False
env.agent.config.flatten_obs = False
env.reset()
_, _, _, _, _ = env.step(0)

View File

@@ -13,7 +13,7 @@ DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yam
def env_with_ssh() -> PrimaiteGymEnv:
"""Build data manipulation environment with SSH port open on router."""
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
env.agent.flatten_obs = False
env.agent.config.agent_settings.flatten_obs = False
router: Router = env.game.simulation.network.get_node_by_hostname("router_1")
router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3)
return env