From 1798ec6fe0168807d2ed1b2f84d8ccaf0198de89 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 5 Dec 2024 14:00:44 +0000 Subject: [PATCH] #2869 - Commit before switching branches. Changes to make pydantic happy with AgentLog --- src/primaite/game/agent/agent_log.py | 18 +++++++++++----- src/primaite/game/agent/interface.py | 11 ++++++++-- .../agent/scripted_agents/abstract_tap.py | 9 ++++++++ .../scripted_agents/data_manipulation_bot.py | 12 +++++------ .../scripted_agents/probabilistic_agent.py | 1 + .../agent/scripted_agents/random_agent.py | 13 +++++++++--- .../game/agent/scripted_agents/tap001.py | 21 ++++++------------- src/primaite/game/game.py | 1 - tests/conftest.py | 1 + 9 files changed, 54 insertions(+), 33 deletions(-) diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index 62ef4884..c292ba4f 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -1,8 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import logging +from abc import ABC from pathlib import Path +from typing import Optional from prettytable import MARKDOWN, PrettyTable +from pydantic import BaseModel from primaite.simulator import LogLevel, SIM_OUTPUT @@ -18,22 +21,27 @@ class _NotJSONFilter(logging.Filter): return not record.getMessage().startswith("{") and not record.getMessage().endswith("}") -class AgentLog: +class AgentLog(BaseModel): """ A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent. Each log message is written to a file located at: /agent_name/agent_name.log """ - def __init__(self, agent_name: str): + agent_name: str = "unnamed_agent" + current_episode: int = 1 + current_timestep: int = 0 + + def __init__(self, agent_name: Optional[str]): """ Constructs a Agent Log instance for a given hostname. :param hostname: The hostname associated with the system logs being recorded. """ - self.agent_name = agent_name - self.current_episode: int = 1 - self.current_timestep: int = 0 + super().__init__() + self.agent_name = agent_name or "unnamed_agent" + # self.current_episode: int = 1 + # self.current_timestep: int = 0 self.setup_logger() @property diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 402c7ce2..1b9dbcd6 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -97,6 +97,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} config: "AbstractAgent.ConfigSchema" + agent_name = "Abstract_Agent" class ConfigSchema(BaseModel): """ @@ -115,13 +116,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): """ type: str - agent_name: ClassVar[str] + agent_name: str = "Abstact_Agent" logger: AgentLog = AgentLog(agent_name) history: List[AgentHistoryItem] = [] action_manager: Optional[ActionManager] = None observation_manager: Optional[ObservationManager] = None reward_function: Optional[RewardFunction] = None - agent_settings = Optional[AgentSettings] = None + agent_settings: Optional[AgentSettings] = None def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -213,6 +214,11 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"): """Base class for actors which generate their own behaviour.""" + class ConfigSchema(AbstractAgent.ConfigSchema): + """Configuration Schema for AbstractScriptedAgents.""" + + agent_name: str = "Abstract_Scripted_Agent" + @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """Return an action to be taken in the environment.""" @@ -227,6 +233,7 @@ class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"): class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Proxy Agent.""" + agent_name: str = "Proxy_Agent" agent_settings = Union[AgentSettings | None] = None most_reason_action: ActType flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py index 2523f9f7..19eeac1a 100644 --- a/src/primaite/game/agent/scripted_agents/abstract_tap.py +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -15,6 +15,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration schema for Abstract TAP agents.""" + agent_name: str = "Abstract_TAP" starting_node_name: str next_execution_timestep: int @@ -32,3 +33,11 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): -self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance ) self.config.next_execution_timestep = timestep + random_timestep_increment + + def _select_start_node(self) -> None: + """Set the starting starting node of the agent to be a random node from this agent's action manager.""" + # we are assuming that every node in the node manager has a data manipulation application at idx 0 + num_nodes = len(self.config.action_manager.node_names) + starting_node_idx = random.randint(0, num_nodes - 1) + self.starting_node_name = self.config.action_manager.node_names[starting_node_idx] + self.config.logger.debug(f"Selected Starting node ID: {self.starting_node_name}") diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index b375da66..3a2dbdd2 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -16,6 +16,11 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen """Configuration Schema for DataManipulationAgent.""" starting_application_name: str + agent_name: str = "Data_Manipulation_Agent" + + def __init__(self) -> None: + """Meh.""" + self.setup_agent() @property def next_execution_timestep(self) -> int: @@ -52,10 +57,3 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen """Set the next execution timestep when the episode resets.""" self._select_start_node() self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step) - - def _select_start_node(self) -> None: - """Set the starting starting node of the agent to be a random node from this agent's action manager.""" - # we are assuming that every node in the node manager has a data manipulation application at idx 0 - num_nodes = len(self.action_manager.node_names) - self.starting_node_idx = random.randint(0, num_nodes - 1) - self.config.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}") diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index 02ac5931..c29719ac 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -18,6 +18,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent" class ConfigSchema(pydantic.BaseModel): """Config schema for Probabilistic agent settings.""" + agent_name: str = "Probabilistic_Agent" model_config = pydantic.ConfigDict(extra="forbid") """Strict validation.""" action_probabilities: Dict[int, float] diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index a9082eda..e11e3352 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -10,7 +10,12 @@ from primaite.game.agent.interface import AbstractScriptedAgent class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration Schema for Random Agents.""" + + agent_name = "Random_Agent" + + def get_action(self) -> Tuple[str, Dict]: """Sample the action space randomly. :param obs: Current observation for this agent, not used in RandomAgent @@ -31,6 +36,8 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"): class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration Schema for Periodic Agent.""" + agent_name = "Periodic_Agent" + """Name of the agent.""" start_step: int = 20 "The timestep at which an agent begins performing it's actions." start_variance: int = 5 @@ -69,9 +76,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"): def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: """Do nothing, unless the current timestep is the next execution timestep, in which case do the action.""" - if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions: + if timestep == self.next_execution_timestep and self.num_executions < self.config.max_executions: self.num_executions += 1 - self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance) + self._set_next_execution_timestep(timestep + self.config.frequency, self.config.variance) return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} return "DONOTHING", {} diff --git a/src/primaite/game/agent/scripted_agents/tap001.py b/src/primaite/game/agent/scripted_agents/tap001.py index d3a82bbe..3b7abe50 100644 --- a/src/primaite/game/agent/scripted_agents/tap001.py +++ b/src/primaite/game/agent/scripted_agents/tap001.py @@ -19,20 +19,19 @@ class TAP001(AbstractTAPAgent, identifier="TAP001"): class ConfigSchema(AbstractTAPAgent.ConfigSchema): """Configuration Schema for TAP001 Agent.""" + agent_name: str = "TAP001" installed: bool = False + def __init__(self) -> None: + """___init___ bruv. Restecpa.""" + super().__init__() + self.setup_agent() + @property def starting_node_name(self) -> str: """Node that TAP001 starts from.""" return self.config.starting_node_name - @classmethod - def from_config(cls, config: Dict) -> TAP001: - """Override the base from_config method to ensure successful agent setup.""" - obj: TAP001 = cls(config=cls.ConfigSchema(**config)) - obj.setup_agent() - return obj - def get_action(self, timestep: int) -> Tuple[str, Dict]: """Waits until a specific timestep, then attempts to execute the ransomware application. @@ -72,11 +71,3 @@ class TAP001(AbstractTAPAgent, identifier="TAP001"): self.ip_address = act[1]["ip_address"] return raise RuntimeError("TAP001 agent could not find database server ip address in action map") - - def _select_start_node(self) -> None: - """Set the starting starting node of the agent to be a random node from this agent's action manager.""" - # we are assuming that every node in the node manager has a data manipulation application at idx 0 - num_nodes = len(self.config.action_manager.node_names) - starting_node_idx = random.randint(0, num_nodes - 1) - self.starting_node_name = self.config.action_manager.node_names[starting_node_idx] - self.config.logger.debug(f"Selected Starting node ID: {self.starting_node_name}") diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 79587e47..9ef75fb9 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -558,7 +558,6 @@ class PrimaiteGame: msg = f"Configuration error: {agent_type} is not a valid agent type." _LOGGER.error(msg) raise ValueError(msg) - game.agents[agent_cfg["ref"]] = new_agent # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. diff --git a/tests/conftest.py b/tests/conftest.py index b24c4c76..27032540 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -272,6 +272,7 @@ class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"): class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Abstract Agent used in tests.""" + agent_name: str = "Controlled_Agent" most_recent_action: Tuple[str, Dict] def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: