#2869 - Changes to AbstractAgent to address some pydantic issues
This commit is contained in:
@@ -42,7 +42,6 @@ AbstractAgent
|
||||
|
||||
Configurable items within a new agent within PrimAITE should contain a ``ConfigSchema`` which holds all configurable variables of the agent. This should not include parameters related to its *state*.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class ExampleAgent(AbstractAgent, identifier = "example_agent"):
|
||||
@@ -63,7 +62,7 @@ AbstractAgent
|
||||
|
||||
#. **identifier**:
|
||||
|
||||
All agent classes should have a unique ``identifier`` attribute, for when they are added to the base ``AbstractAgent`` registry. PrimAITE notation is for these to be written in snake_case
|
||||
All agent classes should have a ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry.
|
||||
|
||||
Changes to YAML file
|
||||
====================
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
|
||||
@@ -20,28 +19,23 @@ class _NotJSONFilter(logging.Filter):
|
||||
return not record.getMessage().startswith("{") and not record.getMessage().endswith("}")
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
class AgentLog:
|
||||
"""
|
||||
A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent.
|
||||
|
||||
Each log message is written to a file located at: <simulation output directory>/agent_name/agent_name.log
|
||||
"""
|
||||
|
||||
agent_name: str = "unnamed_agent"
|
||||
current_episode: int = 1
|
||||
current_timestep: int = 0
|
||||
logger: logging
|
||||
|
||||
def __init__(self, agent_name: Optional[str]):
|
||||
"""
|
||||
Constructs a Agent Log instance for a given hostname.
|
||||
|
||||
:param hostname: The hostname associated with the system logs being recorded.
|
||||
:param agent_name: The agent_name associated with the system logs being recorded.
|
||||
"""
|
||||
super().__init__()
|
||||
self.agent_name = agent_name or "unnamed_agent"
|
||||
# self.current_episode: int = 1
|
||||
# self.current_timestep: int = 0
|
||||
self.agent_name = agent_name if agent_name else "unnamed_agent"
|
||||
self.current_timestep: int = 0
|
||||
self.current_episode: int = 0
|
||||
self.setup_logger()
|
||||
|
||||
@property
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""Interface for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -92,14 +92,12 @@ class AgentSettings(BaseModel):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
class AbstractAgent(BaseModel, identifier="Abstract_Agent"):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
agent_name: str = "Abstact_Agent"
|
||||
logger: AgentLog = AgentLog(agent_name)
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""
|
||||
@@ -117,11 +115,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
:type agent_settings: Optional[AgentSettings]
|
||||
"""
|
||||
|
||||
agent_name: str = "Abstract_Agent"
|
||||
history: List[AgentHistoryItem] = []
|
||||
action_manager: Optional[ActionManager] = None
|
||||
observation_manager: Optional[ObservationManager] = None
|
||||
reward_function: Optional[RewardFunction] = None
|
||||
agent_settings: Optional[AgentSettings] = None
|
||||
_logger: AgentLog = AgentLog(agent_name=agent_name)
|
||||
_action_manager: Optional[ActionManager] = None
|
||||
_observation_manager: Optional[ObservationManager] = None
|
||||
_reward_function: Optional[RewardFunction] = None
|
||||
_agent_settings: Optional[AgentSettings] = None
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -132,22 +132,22 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
@property
|
||||
def logger(self) -> AgentLog:
|
||||
"""Return the AgentLog."""
|
||||
return self.config.logger
|
||||
return self.config._logger
|
||||
|
||||
@property
|
||||
def observation_manager(self) -> ObservationManager:
|
||||
"""Returns the agents observation manager."""
|
||||
return self.config.observation_manager
|
||||
return self.config._observation_manager
|
||||
|
||||
@property
|
||||
def action_manager(self) -> ActionManager:
|
||||
"""Returns the agents action manager."""
|
||||
return self.config.action_manager
|
||||
return self.config._action_manager
|
||||
|
||||
@property
|
||||
def reward_function(self) -> RewardFunction:
|
||||
"""Returns the agents reward function."""
|
||||
return self.config.reward_function
|
||||
return self.config._reward_function
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractAgent":
|
||||
|
||||
@@ -42,11 +42,11 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
self.config.logger.debug(msg="Performing do nothing action")
|
||||
self.logger.debug(msg="Performing do nothing action")
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
|
||||
self.config.logger.info(msg="Performing a data manipulation attack!")
|
||||
self._set_next_execution_timestep(timestep + self.config._agent_settings.start_settings.frequency)
|
||||
self.logger.info(msg="Performing a data manipulation attack!")
|
||||
return "node_application_execute", {
|
||||
"node_name": self.config.starting_node_name,
|
||||
"application_name": self.config.starting_application_name,
|
||||
@@ -55,4 +55,4 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
|
||||
self._set_next_execution_timestep(self.config._agent_settings.start_settings.start_step)
|
||||
|
||||
Reference in New Issue
Block a user