From 436a986458ccaa930b877d64ded7c9031c01f525 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Dec 2024 10:51:57 +0000 Subject: [PATCH] #2869 - Fixed failing tests from agent refactor. Some tests still fail but this is due to updating some action names in anticipation of merging in the extensible actions refactor --- .../agent/scripted_agents/abstract_tap.py | 9 +++++++-- .../scripted_agents/data_manipulation_bot.py | 8 ++------ .../game/agent/scripted_agents/interface.py | 19 ++++++++++++++----- .../assets/configs/test_primaite_session.yaml | 2 ++ 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py index 95769624..add29b03 100644 --- a/src/primaite/game/agent/scripted_agents/abstract_tap.py +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -15,13 +15,18 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): config: "AbstractTAPAgent.ConfigSchema" agent_name: str = "Abstract_TAP" - _next_execution_timestep: int + next_execution_timestep: int = 0 class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration schema for Abstract TAP agents.""" starting_node_name: Optional[str] = None + # @property + # def next_execution_timestep(self) -> int: + # """Returns the agents next execution timestep.""" + # return self.next_execution_timestep + @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """Return an action to be taken in the environment.""" @@ -40,7 +45,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): random_timestep_increment = random.randint( -self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance ) - self._next_execution_timestep = timestep + random_timestep_increment + self.next_execution_timestep = timestep + random_timestep_increment def _select_start_node(self) -> None: """Set the starting starting node of the agent to be a random node from this agent's action manager.""" diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 0f687367..84cad9f6 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -14,18 +14,14 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingA class ConfigSchema(AbstractTAPAgent.ConfigSchema): """Configuration Schema for DataManipulationAgent.""" + starting_application_name: Optional[str] = None # def __init__(self, **kwargs: Any) -> None: # """Initialise DataManipulationAgent.""" - # # self.setup_agent() + # self.setup_agent() # super().__init_subclass__(**kwargs) - @property - def next_execution_timestep(self) -> int: - """Returns the agents next execution timestep.""" - return self._next_execution_timestep - @property def starting_node_name(self) -> str: """Returns the agents starting node name.""" diff --git a/src/primaite/game/agent/scripted_agents/interface.py b/src/primaite/game/agent/scripted_agents/interface.py index 5e9167f5..045d6d12 100644 --- a/src/primaite/game/agent/scripted_agents/interface.py +++ b/src/primaite/game/agent/scripted_agents/interface.py @@ -135,6 +135,15 @@ class AbstractAgent(BaseModel): """Return the AgentLog.""" return self.config._logger + @property + def flatten_obs(self) -> bool: + return self.config.agent_settings.flatten_obs + + @property + def history(self) -> List[AgentHistoryItem]: + """Return the agent history""" + return self.config.history + @property def observation_manager(self) -> ObservationManager: """Returns the agents observation manager.""" @@ -236,7 +245,7 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): """Agent that sends observations to an RL model and receives actions from that model.""" config: "ProxyAgent.ConfigSchema" - _most_recent_action: ActType + most_recent_action: ActType = None class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Proxy Agent.""" @@ -246,10 +255,10 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False action_masking: bool = agent_settings.action_masking if agent_settings else False - @property - def most_recent_action(self) -> ActType: - """Convenience method to access the agents most recent action.""" - return self._most_recent_action + # @property + # def most_recent_action(self) -> ActType: + # """Convenience method to access the agents most recent action.""" + # return self._most_recent_action def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 27cfa240..cf241f3c 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -47,6 +47,8 @@ agents: start_step: 25 frequency: 20 variance: 5 + action_probabilities: + 0: 1.0 - ref: data_manipulation_attacker team: RED