#2869 - Commit before switching branches. Changes to make pydantic happy with AgentLog
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import logging
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
|
||||
@@ -18,22 +21,27 @@ class _NotJSONFilter(logging.Filter):
|
||||
return not record.getMessage().startswith("{") and not record.getMessage().endswith("}")
|
||||
|
||||
|
||||
class AgentLog:
|
||||
class AgentLog(BaseModel):
|
||||
"""
|
||||
A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent.
|
||||
|
||||
Each log message is written to a file located at: <simulation output directory>/agent_name/agent_name.log
|
||||
"""
|
||||
|
||||
def __init__(self, agent_name: str):
|
||||
agent_name: str = "unnamed_agent"
|
||||
current_episode: int = 1
|
||||
current_timestep: int = 0
|
||||
|
||||
def __init__(self, agent_name: Optional[str]):
|
||||
"""
|
||||
Constructs a Agent Log instance for a given hostname.
|
||||
|
||||
:param hostname: The hostname associated with the system logs being recorded.
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
self.current_episode: int = 1
|
||||
self.current_timestep: int = 0
|
||||
super().__init__()
|
||||
self.agent_name = agent_name or "unnamed_agent"
|
||||
# self.current_episode: int = 1
|
||||
# self.current_timestep: int = 0
|
||||
self.setup_logger()
|
||||
|
||||
@property
|
||||
|
||||
@@ -97,6 +97,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
agent_name = "Abstract_Agent"
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""
|
||||
@@ -115,13 +116,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
"""
|
||||
|
||||
type: str
|
||||
agent_name: ClassVar[str]
|
||||
agent_name: str = "Abstact_Agent"
|
||||
logger: AgentLog = AgentLog(agent_name)
|
||||
history: List[AgentHistoryItem] = []
|
||||
action_manager: Optional[ActionManager] = None
|
||||
observation_manager: Optional[ObservationManager] = None
|
||||
reward_function: Optional[RewardFunction] = None
|
||||
agent_settings = Optional[AgentSettings] = None
|
||||
agent_settings: Optional[AgentSettings] = None
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -213,6 +214,11 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for AbstractScriptedAgents."""
|
||||
|
||||
agent_name: str = "Abstract_Scripted_Agent"
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return an action to be taken in the environment."""
|
||||
@@ -227,6 +233,7 @@ class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"):
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Proxy Agent."""
|
||||
|
||||
agent_name: str = "Proxy_Agent"
|
||||
agent_settings = Union[AgentSettings | None] = None
|
||||
most_reason_action: ActType
|
||||
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
|
||||
@@ -15,6 +15,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Abstract TAP agents."""
|
||||
|
||||
agent_name: str = "Abstract_TAP"
|
||||
starting_node_name: str
|
||||
next_execution_timestep: int
|
||||
|
||||
@@ -32,3 +33,11 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
|
||||
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
|
||||
)
|
||||
self.config.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.config.action_manager.node_names)
|
||||
starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.starting_node_name = self.config.action_manager.node_names[starting_node_idx]
|
||||
self.config.logger.debug(f"Selected Starting node ID: {self.starting_node_name}")
|
||||
|
||||
@@ -16,6 +16,11 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
|
||||
"""Configuration Schema for DataManipulationAgent."""
|
||||
|
||||
starting_application_name: str
|
||||
agent_name: str = "Data_Manipulation_Agent"
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Meh."""
|
||||
self.setup_agent()
|
||||
|
||||
@property
|
||||
def next_execution_timestep(self) -> int:
|
||||
@@ -52,10 +57,3 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.config.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
|
||||
|
||||
@@ -18,6 +18,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"
|
||||
class ConfigSchema(pydantic.BaseModel):
|
||||
"""Config schema for Probabilistic agent settings."""
|
||||
|
||||
agent_name: str = "Probabilistic_Agent"
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
|
||||
@@ -10,7 +10,12 @@ from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Random Agents."""
|
||||
|
||||
agent_name = "Random_Agent"
|
||||
|
||||
def get_action(self) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
@@ -31,6 +36,8 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Periodic Agent."""
|
||||
|
||||
agent_name = "Periodic_Agent"
|
||||
"""Name of the agent."""
|
||||
start_step: int = 20
|
||||
"The timestep at which an agent begins performing it's actions."
|
||||
start_variance: int = 5
|
||||
@@ -69,9 +76,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.config.max_executions:
|
||||
self.num_executions += 1
|
||||
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
|
||||
self._set_next_execution_timestep(timestep + self.config.frequency, self.config.variance)
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
|
||||
return "DONOTHING", {}
|
||||
|
||||
@@ -19,20 +19,19 @@ class TAP001(AbstractTAPAgent, identifier="TAP001"):
|
||||
class ConfigSchema(AbstractTAPAgent.ConfigSchema):
|
||||
"""Configuration Schema for TAP001 Agent."""
|
||||
|
||||
agent_name: str = "TAP001"
|
||||
installed: bool = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""___init___ bruv. Restecpa."""
|
||||
super().__init__()
|
||||
self.setup_agent()
|
||||
|
||||
@property
|
||||
def starting_node_name(self) -> str:
|
||||
"""Node that TAP001 starts from."""
|
||||
return self.config.starting_node_name
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> TAP001:
|
||||
"""Override the base from_config method to ensure successful agent setup."""
|
||||
obj: TAP001 = cls(config=cls.ConfigSchema(**config))
|
||||
obj.setup_agent()
|
||||
return obj
|
||||
|
||||
def get_action(self, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
|
||||
@@ -72,11 +71,3 @@ class TAP001(AbstractTAPAgent, identifier="TAP001"):
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.config.action_manager.node_names)
|
||||
starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.starting_node_name = self.config.action_manager.node_names[starting_node_idx]
|
||||
self.config.logger.debug(f"Selected Starting node ID: {self.starting_node_name}")
|
||||
|
||||
@@ -558,7 +558,6 @@ class PrimaiteGame:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
|
||||
@@ -272,6 +272,7 @@ class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Abstract Agent used in tests."""
|
||||
|
||||
agent_name: str = "Controlled_Agent"
|
||||
most_recent_action: Tuple[str, Dict]
|
||||
|
||||
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
|
||||
Reference in New Issue
Block a user