#2869 - More YAML/test fixes to address failures
This commit is contained in:
@@ -260,11 +260,6 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
|
||||
# @property
|
||||
# def most_recent_action(self) -> ActType:
|
||||
# """Convenience method to access the agents most recent action."""
|
||||
# return self._most_recent_action
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
@@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
)
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
if agent.action_masking:
|
||||
if agent.config.action_masking:
|
||||
self.observation_space[agent_name] = spaces.Dict(
|
||||
{
|
||||
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
|
||||
@@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if agent.action_masking:
|
||||
if agent.config.action_masking:
|
||||
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
|
||||
else:
|
||||
all_obs[agent_name] = obs
|
||||
@@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
self.observation_space = spaces.Dict(
|
||||
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
|
||||
)
|
||||
@@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
|
||||
return new_obs, *_
|
||||
|
||||
@@ -60,6 +60,9 @@ agents:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
action_probabilities:
|
||||
0: 0.4
|
||||
1: 0.6
|
||||
|
||||
simulation:
|
||||
network:
|
||||
|
||||
@@ -85,6 +85,9 @@ agents:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
action_probabilities:
|
||||
0: 0.4
|
||||
1: 0.6
|
||||
|
||||
|
||||
simulation:
|
||||
|
||||
@@ -92,6 +92,9 @@ agents:
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
agent_settings:
|
||||
flatten_obs: True
|
||||
action_masking: False
|
||||
|
||||
simulation:
|
||||
network:
|
||||
|
||||
@@ -49,7 +49,7 @@ def test_application_install_uninstall_on_uc2():
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
env.agent.flatten_obs = False
|
||||
env.agent.config.flatten_obs = False
|
||||
env.reset()
|
||||
|
||||
_, _, _, _, _ = env.step(0)
|
||||
|
||||
@@ -13,7 +13,7 @@ DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yam
|
||||
def env_with_ssh() -> PrimaiteGymEnv:
|
||||
"""Build data manipulation environment with SSH port open on router."""
|
||||
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
|
||||
env.agent.flatten_obs = False
|
||||
env.agent.config.agent_settings.flatten_obs = False
|
||||
router: Router = env.game.simulation.network.get_node_by_hostname("router_1")
|
||||
router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3)
|
||||
return env
|
||||
|
||||
Reference in New Issue
Block a user