diff --git a/src/primaite/game/agent/scripted_agents/interface.py b/src/primaite/game/agent/scripted_agents/interface.py index ab78eee0..e0dc61f2 100644 --- a/src/primaite/game/agent/scripted_agents/interface.py +++ b/src/primaite/game/agent/scripted_agents/interface.py @@ -260,11 +260,6 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False action_masking: bool = agent_settings.action_masking if agent_settings else False - # @property - # def most_recent_action(self) -> ActType: - # """Convenience method to access the agents most recent action.""" - # return self._most_recent_action - def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 2d540237..5d15ffa2 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] - if agent.action_masking: + if agent.config.action_masking: self.observation_space[agent_name] = spaces.Dict( { "action_mask": spaces.MultiBinary(agent.action_manager.space.n), @@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation obs = gymnasium.spaces.flatten(unflat_space, unflat_obs) - if agent.action_masking: + if agent.config.action_masking: all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs} else: all_obs[agent_name] = obs @@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env): self.env = PrimaiteGymEnv(env_config=env_config) # self.env.episode_counter -= 1 self.action_space = self.env.action_space - if self.env.agent.action_masking: + if self.env.agent.config.agent_settings.action_masking: self.observation_space = spaces.Dict( {"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space} ) @@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" super().reset() # Ensure PRNG seed is set everywhere - if self.env.agent.action_masking: + if self.env.agent.config.action_masking: obs, *_ = self.env.reset(seed=seed) new_obs = {"action_mask": self.env.action_masks(), "observations": obs} return new_obs, *_ @@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env): def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" # if action masking is enabled, intercept the step method and add action mask to observation - if self.env.agent.action_masking: + if self.env.agent.config.action_masking: obs, *_ = self.env.step(action) new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs} return new_obs, *_ diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml index 0253a4d2..e37a67da 100644 --- a/tests/assets/configs/basic_firewall.yaml +++ b/tests/assets/configs/basic_firewall.yaml @@ -60,6 +60,9 @@ agents: start_step: 5 frequency: 4 variance: 3 + action_probabilities: + 0: 0.4 + 1: 0.6 simulation: network: diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 52316260..d560efa3 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -85,6 +85,9 @@ agents: start_step: 5 frequency: 4 variance: 3 + action_probabilities: + 0: 0.4 + 1: 0.6 simulation: diff --git a/tests/assets/configs/install_and_configure_apps.yaml b/tests/assets/configs/install_and_configure_apps.yaml index 6b548f7e..18a9724b 100644 --- a/tests/assets/configs/install_and_configure_apps.yaml +++ b/tests/assets/configs/install_and_configure_apps.yaml @@ -92,6 +92,9 @@ agents: reward_function: reward_components: - type: DUMMY + agent_settings: + flatten_obs: True + action_masking: False simulation: network: diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 7ec38d72..1cf2ceea 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -49,7 +49,7 @@ def test_application_install_uninstall_on_uc2(): cfg = yaml.safe_load(f) env = PrimaiteGymEnv(env_config=cfg) - env.agent.flatten_obs = False + env.agent.config.flatten_obs = False env.reset() _, _, _, _, _ = env.step(0) diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py index e7287eee..b7af3ec8 100644 --- a/tests/integration_tests/game_layer/observations/test_user_observations.py +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -13,7 +13,7 @@ DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yam def env_with_ssh() -> PrimaiteGymEnv: """Build data manipulation environment with SSH port open on router.""" env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) - env.agent.flatten_obs = False + env.agent.config.agent_settings.flatten_obs = False router: Router = env.game.simulation.network.get_node_by_hostname("router_1") router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3) return env