From c3a70be8d14ddc64c7a31a58b8d2354dfe5f529f Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Dec 2024 16:37:39 +0000 Subject: [PATCH] #2869 - Changes to AbstractAgent to address some pydantic issues --- .../how_to_guides/extensible_agents.rst | 3 +-- src/primaite/game/agent/agent_log.py | 16 ++++--------- src/primaite/game/agent/interface.py | 24 +++++++++---------- .../scripted_agents/data_manipulation_bot.py | 8 +++---- 4 files changed, 22 insertions(+), 29 deletions(-) diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst index 718ea09a..b694f882 100644 --- a/docs/source/how_to_guides/extensible_agents.rst +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -42,7 +42,6 @@ AbstractAgent Configurable items within a new agent within PrimAITE should contain a ``ConfigSchema`` which holds all configurable variables of the agent. This should not include parameters related to its *state*. - .. code-block:: python class ExampleAgent(AbstractAgent, identifier = "example_agent"): @@ -63,7 +62,7 @@ AbstractAgent #. **identifier**: - All agent classes should have a unique ``identifier`` attribute, for when they are added to the base ``AbstractAgent`` registry. PrimAITE notation is for these to be written in snake_case + All agent classes should have a ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry. Changes to YAML file ==================== diff --git a/src/primaite/game/agent/agent_log.py b/src/primaite/game/agent/agent_log.py index 6eaf9e73..f12c49f7 100644 --- a/src/primaite/game/agent/agent_log.py +++ b/src/primaite/game/agent/agent_log.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Optional from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel from primaite.simulator import LogLevel, SIM_OUTPUT @@ -20,28 +19,23 @@ class _NotJSONFilter(logging.Filter): return not record.getMessage().startswith("{") and not record.getMessage().endswith("}") -class AgentLog(BaseModel): +class AgentLog: """ A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent. Each log message is written to a file located at: /agent_name/agent_name.log """ - agent_name: str = "unnamed_agent" - current_episode: int = 1 - current_timestep: int = 0 - logger: logging - def __init__(self, agent_name: Optional[str]): """ Constructs a Agent Log instance for a given hostname. - :param hostname: The hostname associated with the system logs being recorded. + :param agent_name: The agent_name associated with the system logs being recorded. """ super().__init__() - self.agent_name = agent_name or "unnamed_agent" - # self.current_episode: int = 1 - # self.current_timestep: int = 0 + self.agent_name = agent_name if agent_name else "unnamed_agent" + self.current_timestep: int = 0 + self.current_episode: int = 0 self.setup_logger() @property diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 5bef1076..0c208f71 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -2,7 +2,7 @@ """Interface for agents.""" from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union from gymnasium.core import ActType, ObsType @@ -92,14 +92,12 @@ class AgentSettings(BaseModel): return cls(**config) -class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): +class AbstractAgent(BaseModel, identifier="Abstract_Agent"): """Base class for scripted and RL agents.""" _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} config: "AbstractAgent.ConfigSchema" - agent_name: str = "Abstact_Agent" - logger: AgentLog = AgentLog(agent_name) class ConfigSchema(BaseModel): """ @@ -117,11 +115,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): :type agent_settings: Optional[AgentSettings] """ + agent_name: str = "Abstract_Agent" history: List[AgentHistoryItem] = [] - action_manager: Optional[ActionManager] = None - observation_manager: Optional[ObservationManager] = None - reward_function: Optional[RewardFunction] = None - agent_settings: Optional[AgentSettings] = None + _logger: AgentLog = AgentLog(agent_name=agent_name) + _action_manager: Optional[ActionManager] = None + _observation_manager: Optional[ObservationManager] = None + _reward_function: Optional[RewardFunction] = None + _agent_settings: Optional[AgentSettings] = None def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -132,22 +132,22 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): @property def logger(self) -> AgentLog: """Return the AgentLog.""" - return self.config.logger + return self.config._logger @property def observation_manager(self) -> ObservationManager: """Returns the agents observation manager.""" - return self.config.observation_manager + return self.config._observation_manager @property def action_manager(self) -> ActionManager: """Returns the agents action manager.""" - return self.config.action_manager + return self.config._action_manager @property def reward_function(self) -> RewardFunction: """Returns the agents reward function.""" - return self.config.reward_function + return self.config._reward_function @classmethod def from_config(cls, config: Dict) -> "AbstractAgent": diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 247e815a..5927cd09 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -42,11 +42,11 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - self.config.logger.debug(msg="Performing do nothing action") + self.logger.debug(msg="Performing do nothing action") return "do_nothing", {} - self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency) - self.config.logger.info(msg="Performing a data manipulation attack!") + self._set_next_execution_timestep(timestep + self.config._agent_settings.start_settings.frequency) + self.logger.info(msg="Performing a data manipulation attack!") return "node_application_execute", { "node_name": self.config.starting_node_name, "application_name": self.config.starting_application_name, @@ -55,4 +55,4 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen def setup_agent(self) -> None: """Set the next execution timestep when the episode resets.""" self._select_start_node() - self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step) + self._set_next_execution_timestep(self.config._agent_settings.start_settings.start_step)