From dc6f2be20932f716fd70b64c40c97d2963be8f4d Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Dec 2024 12:50:14 +0000 Subject: [PATCH] #2869 - pre-commit changes --- .../game/agent/scripted_agents/abstract_tap.py | 5 +++-- .../scripted_agents/data_manipulation_bot.py | 5 +++-- .../game/agent/scripted_agents/interface.py | 12 +++++------- .../agent/scripted_agents/probabilistic_agent.py | 7 +++++-- .../game/agent/scripted_agents/random_agent.py | 7 ++++--- tests/conftest.py | 9 ++++----- .../_game/_agent/test_probabilistic_agent.py | 16 +++++++++------- 7 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py index d30ba9b1..725d3525 100644 --- a/src/primaite/game/agent/scripted_agents/abstract_tap.py +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -7,9 +7,10 @@ from typing import Dict, Optional, Tuple from gymnasium.core import ObsType -from primaite.game.agent.scripted_agents.interface import AbstractAgent, AbstractScriptedAgent +from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent + +__all__ = "AbstractTAPAgent" -__all__ = ("AbstractTAPAgent") class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): """Base class for TAP agents to inherit from.""" diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 594f1b41..d6213f67 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -1,11 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Optional, Tuple from gymnasium.core import ObsType from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent -__all__ = ("DataManipulationAgent") +__all__ = "DataManipulationAgent" + class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingAgent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" diff --git a/src/primaite/game/agent/scripted_agents/interface.py b/src/primaite/game/agent/scripted_agents/interface.py index e0dc61f2..e6c2d6b3 100644 --- a/src/primaite/game/agent/scripted_agents/interface.py +++ b/src/primaite/game/agent/scripted_agents/interface.py @@ -17,11 +17,8 @@ from primaite.interface.request import RequestFormat, RequestResponse if TYPE_CHECKING: pass -__all__ = ("AgentHistoryItem", - "AgentStartSettings", - "AbstractAgent", - "AbstractScriptedAgent", - "ProxyAgent") +__all__ = ("AgentHistoryItem", "AgentStartSettings", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent") + class AgentHistoryItem(BaseModel): """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" @@ -121,7 +118,7 @@ class AbstractAgent(BaseModel): """ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - agent_name: ClassVar[str] = "Abstract_Agent" # TODO: Make this a ClassVar[str] like verb in actions? + agent_name: ClassVar[str] = "Abstract_Agent" # TODO: Make this a ClassVar[str] like verb in actions? history: List[AgentHistoryItem] = [] _logger: AgentLog = AgentLog(agent_name=agent_name) action_manager: ActionManager @@ -142,11 +139,12 @@ class AbstractAgent(BaseModel): @property def flatten_obs(self) -> bool: + """Return agent flatten_obs param.""" return self.config.agent_settings.flatten_obs @property def history(self) -> List[AgentHistoryItem]: - """Return the agent history""" + """Return the agent history.""" return self.config.history @property diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index ba6ba850..533f0628 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -8,7 +8,8 @@ from gymnasium.core import ObsType from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings -__all__ = ("ProbabilisticAgent") +__all__ = "ProbabilisticAgent" + class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"): """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" @@ -17,8 +18,11 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") rng: Any = np.random.default_rng(np.random.randint(0, 65535)) class AgentSettings(AgentSettings): + """ProbabilisticAgent settings.""" + action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" + @pydantic.field_validator("action_probabilities", mode="after") @classmethod def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: @@ -44,7 +48,6 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") agent_name: str = "ProbabilisticAgent" agent_settings: "ProbabilisticAgent.AgentSettings" - @property def probabilities(self) -> Dict[str, int]: """Convenience method to view the probabilities of the Agent.""" diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index fecc235f..fadaa66c 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -8,6 +8,7 @@ from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent __all__ = ("RandomAgent", "PeriodicAgent") + class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"): """Agent that ignores its observation and acts completely at random.""" @@ -45,11 +46,11 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"): "The timestep at which an agent begins performing it's actions." start_variance: int = 5 "Deviation around the start step." - + # TODO: This is available in config.agent_settings.start_settings.frequency frequency: int = 5 "The number of timesteps to wait between performing actions." - + # TODO: This is available in config.agent_settings.start_settings.variance variance: int = 0 "The amount the frequency can randomly change to." @@ -57,7 +58,7 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"): "Maximum number of times the agent can execute its action." num_executions: int = 0 """Number of times the agent has executed an action.""" - #TODO: Also in abstract_tap - move up and inherit? Add to AgentStartSettings? + # TODO: Also in abstract_tap - move up and inherit? Add to AgentStartSettings? next_execution_timestep: int = 0 """Timestep of the next action execution by the agent.""" diff --git a/tests/conftest.py b/tests/conftest.py index 68097830..319e306d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -497,12 +497,11 @@ def game_and_agent(): observation_space = ObservationManager(NestedObservation(components={})) reward_function = RewardFunction() - config = { - "agent_name":"test_agent", - "action_manager":action_space, - "observation_manager":observation_space, - "reward_function":reward_function, + "agent_name": "test_agent", + "action_manager": action_space, + "observation_manager": observation_space, + "reward_function": reward_function, } test_agent = ControlledAgent.from_config(config=config) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index 6e8c9c79..b6a49170 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -65,13 +65,15 @@ def test_probabilistic_agent(): # }, # ) - pa_config = {"agent_name":"test_agent", - "action_manager": action_space, - "observation_manager": observation_space, - "reward_function": reward_function, - "agent_settings": { - "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, - }} + pa_config = { + "agent_name": "test_agent", + "action_manager": action_space, + "observation_manager": observation_space, + "reward_function": reward_function, + "agent_settings": { + "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + }, + } pa = ProbabilisticAgent.from_config(config=pa_config)