#2869 - Starter changes in the refactor of agent classes to make them extensible. Identifiers added to classes, and beginning of the inclusion of a ConfigSchema in the base AbstractAgent class

This commit is contained in:
Charlie Crane
2024-11-20 17:19:35 +00:00
parent 6844bd692a
commit a3dc616126
7 changed files with 174 additions and 80 deletions

View File

@@ -1,7 +1,9 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""Interface for agents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
@@ -69,7 +71,7 @@ class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions"
"Configuration for when an agent begins performing it's actions."
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
@@ -90,38 +92,78 @@ class AgentSettings(BaseModel):
return cls(**config)
class AbstractAgent(ABC):
class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
"""Base class for scripted and RL agents."""
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings] = None,
) -> None:
"""
Initialize an agent.
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
config: "AbstractAgent.ConfigSchema"
action_manager: Optional[ActionManager]
observation_manager: Optional[ObservationManager]
reward_function: Optional[RewardFunction]
class ConfigSchema(BaseModel):
"""
Configuration Schema for AbstractAgents.
:param type: Type of agent being generated.
:type type: str
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
:type agent_name: Optional[str]
:param action_space: Action space for the agent.
:type action_space: Optional[ActionManager]
:type agent_name: str
:param observation_space: Observation space for the agent.
:type observation_space: Optional[ObservationSpace]
:param reward_function: Reward function for the agent.
:type reward_function: Optional[RewardFunction]
:param agent_settings: Configurable Options for Abstracted Agents
:param agent_settings: Configurable Options for Abstracted Agents.
:type agent_settings: Optional[AgentSettings]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.history: List[AgentHistoryItem] = []
self.logger = AgentLog(agent_name)
type: str
agent_name: ClassVar[str]
agent_settings = Optional[AgentSettings] = None
history: List[AgentHistoryItem] = []
logger: AgentLog = AgentLog(agent_name)
# def __init__(
# self,
# agent_name: Optional[str],
# action_space: Optional[ActionManager],
# observation_space: Optional[ObservationManager],
# reward_function: Optional[RewardFunction],
# agent_settings: Optional[AgentSettings] = None,
# ) -> None:
# """
# Initialize an agent.
# :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
# :type agent_name: Optional[str]
# :param action_space: Action space for the agent.
# :type action_space: Optional[ActionManager]
# :param observation_space: Observation space for the agent.
# :type observation_space: Optional[ObservationSpace]
# :param reward_function: Reward function for the agent.
# :type reward_function: Optional[RewardFunction]
# :param agent_settings: Configurable Options for Abstracted Agents
# :type agent_settings: Optional[AgentSettings]
# """
# self.agent_name: str = agent_name or "unnamed_agent"
# self.action_manager: Optional[ActionManager] = action_space
# self.observation_manager: Optional[ObservationManager] = observation_space
# self.reward_function: Optional[RewardFunction] = reward_function
# self.agent_settings = agent_settings or AgentSettings()
# self.history: List[AgentHistoryItem] = []
# self.logger = AgentLog(agent_name)
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
return cls(config=cls.ConfigSchema(**config))
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -130,7 +172,7 @@ class AbstractAgent(ABC):
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_manager.update(state)
return self.config.observation_manager.update(state)
def update_reward(self, state: Dict) -> float:
"""
@@ -141,7 +183,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state=state, last_action_response=self.history[-1])
return self.config.reward_function.update(state=state, last_action_response=self.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -165,14 +207,14 @@ class AbstractAgent(ABC):
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.action_manager.form_request(action_identifier=action, action_options=options)
request = self.config.action_manager.form_request(action_identifier=action, action_options=options)
return request
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.history.append(
self.config.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
@@ -180,10 +222,10 @@ class AbstractAgent(ABC):
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.history[-1].reward = self.reward_function.current_reward
self.config.history[-1].reward = self.config.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent):
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
"""Base class for actors which generate their own behaviour."""
@abstractmethod
@@ -192,26 +234,34 @@ class AbstractScriptedAgent(AbstractAgent):
return super().get_action(obs=obs, timestep=timestep)
class ProxyAgent(AbstractAgent):
class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"):
"""Agent that sends observations to an RL model and receives actions from that model."""
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings] = None,
) -> None:
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Proxy Agent."""
agent_settings = Union[AgentSettings | None] = None
most_reason_action: ActType
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
action_masking: bool = agent_settings.action_masking if agent_settings else False
# def __init__(
# self,
# agent_name: Optional[str],
# action_space: Optional[ActionManager],
# observation_space: Optional[ObservationManager],
# reward_function: Optional[RewardFunction],
# agent_settings: Optional[AgentSettings] = None,
# ) -> None:
# super().__init__(
# agent_name=agent_name,
# action_space=action_space,
# observation_space=observation_space,
# reward_function=reward_function,
# )
# self.most_recent_action: ActType
# self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
# self.action_masking: bool = agent_settings.action_masking if agent_settings else False
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
@@ -224,7 +274,7 @@ class ProxyAgent(AbstractAgent):
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.most_recent_action)
return self.config.action_manager.get_action(self.most_recent_action)
def store_action(self, action: ActType):
"""
@@ -232,4 +282,4 @@ class ProxyAgent(AbstractAgent):
The environment is responsible for calling this method when it receives an action from the agent policy.
"""
self.most_recent_action = action
self.config.most_recent_action = action

View File

@@ -7,12 +7,22 @@ from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class DataManipulationAgent(AbstractScriptedAgent):
class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation_Agent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
next_execution_timestep: int = 0
starting_node_idx: int = 0
config: "DataManipulationAgent.ConfigSchema"
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for DataManipulationAgent."""
# TODO: Could be worth moving this to a "AbstractTAPAgent"
starting_node_name: str
starting_application_name: str
next_execution_timestep: int
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
@@ -38,12 +48,15 @@ class DataManipulationAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
self.logger.debug(msg="Performing do NOTHING")
return "DONOTHING", {}
self.logger.debug(msg="Performing do nothing action")
return "do_nothing", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
self.logger.info(msg="Performing a data manipulation attack!")
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
return "node_application_execute", {
"node_name": self.config.starting_node_name,
"application_name": self.config.starting_application_name,
}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""

View File

@@ -12,7 +12,7 @@ from primaite.game.agent.observations.observation_manager import ObservationMana
from primaite.game.agent.rewards import RewardFunction
class ProbabilisticAgent(AbstractScriptedAgent):
class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
class Settings(pydantic.BaseModel):

View File

@@ -11,7 +11,7 @@ from primaite.game.agent.observations.observation_manager import ObservationMana
from primaite.game.agent.rewards import RewardFunction
class RandomAgent(AbstractScriptedAgent):
class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -27,7 +27,7 @@ class RandomAgent(AbstractScriptedAgent):
return self.action_manager.get_action(self.action_manager.space.sample())
class PeriodicAgent(AbstractScriptedAgent):
class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
class Settings(BaseModel):

View File

@@ -1,4 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
import random
from typing import Dict, Tuple
@@ -7,20 +9,27 @@ from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class TAP001(AbstractScriptedAgent):
class TAP001(AbstractScriptedAgent, identifier="TAP001"):
"""
TAP001 | Mobile Malware -- Ransomware Variant.
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
# TODO: Link with DataManipulationAgent via a parent "TAP" agent class.
next_execution_timestep: int = 0
starting_node_idx: int = 0
installed: bool = False
config: "TAP001.ConfigSchema"
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for TAP001 Agent."""
starting_node_name: str
next_execution_timestep: int = 0
installed: bool = False
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.setup_agent()
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -28,9 +37,9 @@ class TAP001(AbstractScriptedAgent):
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
self.config.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute the ransomware application.
@@ -45,28 +54,28 @@ class TAP001(AbstractScriptedAgent):
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {}
if timestep < self.config.next_execution_timestep:
return "do_nothing", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
if not self.installed:
self.installed = True
return "NODE_APPLICATION_INSTALL", {
"node_id": self.starting_node_idx,
if not self.config.installed:
self.config.installed = True
return "node_application_install", {
"node_name": self.config.starting_node_name,
"application_name": "RansomwareScript",
}
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
return "node_application_execute", {"node_name": self.config.starting_node_name, "application_id": 0}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
for n, act in self.action_manager.action_map.items():
if not act[0] == "NODE_APPLICATION_INSTALL":
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
for n, act in self.config.action_manager.action_map.items():
if not act[0] == "node_application_install":
continue
if act[1]["node_id"] == self.starting_node_idx:
if act[1]["node_name"] == self.config.starting_node_name:
self.ip_address = act[1]["ip_address"]
return
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
@@ -74,5 +83,7 @@ class TAP001(AbstractScriptedAgent):
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
num_nodes = len(self.config.action_manager.node_names)
# TODO: Change this to something?
self.starting_node_idx = random.randint(0, num_nodes - 1)
self.logger.debug(f"Selected Starting node ID: {self.starting_node_idx}")

View File

@@ -547,6 +547,26 @@ class PrimaiteGame:
reward_function = RewardFunction.from_config(reward_function_cfg)
# CREATE AGENT
# TODO: MAKE THIS BIT WORK AND NOT THE IF/ELSE CHAIN OF HORRORS
# Pass through:
# config
# action manager
# observation_manager
# reward_function
new_agent_cfg = {
"action_manager": action_space,
"agent_name": agent_cfg["ref"],
"observation_manager": obs_space,
"agent_settings": agent_cfg.get("agent_settings", {}),
"reward_function": reward_function,
}
new_agent_cfg = agent_cfg["settings"]
# new_agent_cfg.update{}
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=new_agent_cfg)
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing
settings = agent_cfg.get("agent_settings", {})

View File

@@ -264,7 +264,7 @@ def example_network() -> Network:
return network
class ControlledAgent(AbstractAgent):
class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
"""Agent that can be controlled by the tests."""
def __init__(