#2869 - Initial changes in the refactor of the agent classes to make them extensible. Identifiers added to the agent classes, and a ConfigSchema begins to be introduced on the base AbstractAgent class
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
"""Interface for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -69,7 +71,7 @@ class AgentSettings(BaseModel):
|
||||
"""Settings for configuring the operation of an agent."""
|
||||
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
"Configuration for when an agent begins performing it's actions."
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = False
|
||||
@@ -90,38 +92,78 @@ class AgentSettings(BaseModel):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
action_manager: Optional[ActionManager]
|
||||
observation_manager: Optional[ObservationManager]
|
||||
reward_function: Optional[RewardFunction]
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""
|
||||
Configuration Schema for AbstractAgents.
|
||||
|
||||
:param type: Type of agent being generated.
|
||||
:type type: str
|
||||
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
:type agent_name: Optional[str]
|
||||
:param action_space: Action space for the agent.
|
||||
:type action_space: Optional[ActionManager]
|
||||
:type agent_name: str
|
||||
:param observation_space: Observation space for the agent.
|
||||
:type observation_space: Optional[ObservationSpace]
|
||||
:param reward_function: Reward function for the agent.
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
:param agent_settings: Configurable Options for Abstracted Agents
|
||||
:param agent_settings: Configurable Options for Abstracted Agents.
|
||||
:type agent_settings: Optional[AgentSettings]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.history: List[AgentHistoryItem] = []
|
||||
self.logger = AgentLog(agent_name)
|
||||
|
||||
type: str
|
||||
agent_name: ClassVar[str]
|
||||
agent_settings = Optional[AgentSettings] = None
|
||||
history: List[AgentHistoryItem] = []
|
||||
logger: AgentLog = AgentLog(agent_name)
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# agent_name: Optional[str],
|
||||
# action_space: Optional[ActionManager],
|
||||
# observation_space: Optional[ObservationManager],
|
||||
# reward_function: Optional[RewardFunction],
|
||||
# agent_settings: Optional[AgentSettings] = None,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Initialize an agent.
|
||||
|
||||
# :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
# :type agent_name: Optional[str]
|
||||
# :param action_space: Action space for the agent.
|
||||
# :type action_space: Optional[ActionManager]
|
||||
# :param observation_space: Observation space for the agent.
|
||||
# :type observation_space: Optional[ObservationSpace]
|
||||
# :param reward_function: Reward function for the agent.
|
||||
# :type reward_function: Optional[RewardFunction]
|
||||
# :param agent_settings: Configurable Options for Abstracted Agents
|
||||
# :type agent_settings: Optional[AgentSettings]
|
||||
# """
|
||||
# self.agent_name: str = agent_name or "unnamed_agent"
|
||||
# self.action_manager: Optional[ActionManager] = action_space
|
||||
# self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
# self.reward_function: Optional[RewardFunction] = reward_function
|
||||
# self.agent_settings = agent_settings or AgentSettings()
|
||||
# self.history: List[AgentHistoryItem] = []
|
||||
# self.logger = AgentLog(agent_name)
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
    """
    Register a newly defined agent subclass in the shared agent registry.

    :param identifier: Unique name under which the subclass is stored in ``_registry``. When omitted, the
        subclass is created as normal but is not added to the registry.
    :type identifier: Optional[str]
    :param kwargs: Remaining class keyword arguments, forwarded unchanged to the parent ``__init_subclass__``.
    :type kwargs: Any
    :raises ValueError: If ``identifier`` is already present in the registry.
    """
    super().__init_subclass__(**kwargs)
    if identifier is None:
        # Subclasses declared without an identifier must not fail at class-creation time;
        # they are simply left out of the registry.
        return
    if identifier in cls._registry:
        raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
    cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
    """
    Create an agent from a configuration dictionary.

    The dictionary is unpacked into this class's ``ConfigSchema`` and the resulting schema object is
    passed to the class constructor as ``config``.

    :param config: Keyword values accepted by ``cls.ConfigSchema``.
    :type config: Dict
    :return: A new instance of the class this method was called on.
    :rtype: AbstractAgent
    """
    schema = cls.ConfigSchema(**config)
    return cls(config=schema)
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -130,7 +172,7 @@ class AbstractAgent(ABC):
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_manager.update(state)
|
||||
return self.config.observation_manager.update(state)
|
||||
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
@@ -141,7 +183,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
return self.config.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -165,14 +207,14 @@ class AbstractAgent(ABC):
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
request = self.config.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.history.append(
|
||||
self.config.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
@@ -180,10 +222,10 @@ class AbstractAgent(ABC):
|
||||
|
||||
def save_reward_to_history(self) -> None:
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
self.config.history[-1].reward = self.config.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -192,26 +234,34 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
return super().get_action(obs=obs, timestep=timestep)
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
    """Configuration Schema for Proxy Agent."""

    # Declared as an annotated field with a default. The previous form used a chained assignment
    # (`name = Union[...] = None`), which is not an annotation and raises when the class body runs.
    agent_settings: Optional[AgentSettings] = None

    # Latest action handed over by the policy; written by `store_action`, which uses this exact name.
    # NOTE(review): defaulted to None so a config can be built before any action exists - confirm.
    most_recent_action: Optional[ActType] = None

    # The earlier defaults were expressions evaluated once in the class body, where `agent_settings`
    # is always None, so they always reduced to False. Spelled out explicitly here.
    # NOTE(review): if these should follow `agent_settings`, derive them in a validator instead.
    flatten_obs: bool = False
    action_masking: bool = False
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# agent_name: Optional[str],
|
||||
# action_space: Optional[ActionManager],
|
||||
# observation_space: Optional[ObservationManager],
|
||||
# reward_function: Optional[RewardFunction],
|
||||
# agent_settings: Optional[AgentSettings] = None,
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# agent_name=agent_name,
|
||||
# action_space=action_space,
|
||||
# observation_space=observation_space,
|
||||
# reward_function=reward_function,
|
||||
# )
|
||||
# self.most_recent_action: ActType
|
||||
# self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
# self.action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
@@ -224,7 +274,7 @@ class ProxyAgent(AbstractAgent):
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.most_recent_action)
|
||||
return self.config.action_manager.get_action(self.most_recent_action)
|
||||
|
||||
def store_action(self, action: ActType):
|
||||
"""
|
||||
@@ -232,4 +282,4 @@ class ProxyAgent(AbstractAgent):
|
||||
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
self.config.most_recent_action = action
|
||||
|
||||
@@ -7,12 +7,22 @@ from gymnasium.core import ObsType
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation_Agent"):
|
||||
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
|
||||
config: "DataManipulationAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
    """Configuration Schema for DataManipulationAgent."""

    # TODO: Could be worth moving this to a "AbstractTAPAgent"

    # Sent as "node_name" in the agent's node_application_execute action.
    starting_node_name: str

    # Sent as "application_name" in the agent's node_application_execute action.
    starting_application_name: str

    # NOTE(review): presumably the timestep at which the agent next acts - confirm against get_action.
    next_execution_timestep: int
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
@@ -38,12 +48,15 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
self.logger.debug(msg="Performing do NOTHING")
|
||||
return "DONOTHING", {}
|
||||
self.logger.debug(msg="Performing do nothing action")
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
self.logger.info(msg="Performing a data manipulation attack!")
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
return "node_application_execute", {
|
||||
"node_name": self.config.starting_node_name,
|
||||
"application_name": self.config.starting_application_name,
|
||||
}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
|
||||
@@ -12,7 +12,7 @@ from primaite.game.agent.observations.observation_manager import ObservationMana
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"):
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
class Settings(pydantic.BaseModel):
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite.game.agent.observations.observation_manager import ObservationMana
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -27,7 +27,7 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class PeriodicAgent(AbstractScriptedAgent):
|
||||
class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
|
||||
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
|
||||
|
||||
class Settings(BaseModel):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
@@ -7,20 +9,27 @@ from gymnasium.core import ObsType
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class TAP001(AbstractScriptedAgent):
|
||||
class TAP001(AbstractScriptedAgent, identifier="TAP001"):
|
||||
"""
|
||||
TAP001 | Mobile Malware -- Ransomware Variant.
|
||||
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
# TODO: Link with DataManipulationAgent via a parent "TAP" agent class.
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
installed: bool = False
|
||||
config: "TAP001.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
    """Configuration Schema for TAP001 Agent."""

    # Sent as "node_name" in the install/execute actions and matched against the action map in setup_agent.
    starting_node_name: str

    # get_action returns "do_nothing" while the current timestep is below this value.
    next_execution_timestep: int = 0

    # Set to True by get_action once the install action has been issued.
    installed: bool = False
|
||||
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
# self.setup_agent()
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
@@ -28,9 +37,9 @@ class TAP001(AbstractScriptedAgent):
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
self.config.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
@@ -45,28 +54,28 @@ class TAP001(AbstractScriptedAgent):
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {}
|
||||
if timestep < self.config.next_execution_timestep:
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
|
||||
|
||||
if not self.installed:
|
||||
self.installed = True
|
||||
return "NODE_APPLICATION_INSTALL", {
|
||||
"node_id": self.starting_node_idx,
|
||||
if not self.config.installed:
|
||||
self.config.installed = True
|
||||
return "node_application_install", {
|
||||
"node_name": self.config.starting_node_name,
|
||||
"application_name": "RansomwareScript",
|
||||
}
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
return "node_application_execute", {"node_name": self.config.starting_node_name, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
for n, act in self.action_manager.action_map.items():
|
||||
if not act[0] == "NODE_APPLICATION_INSTALL":
|
||||
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
|
||||
for n, act in self.config.action_manager.action_map.items():
|
||||
if not act[0] == "node_application_install":
|
||||
continue
|
||||
if act[1]["node_id"] == self.starting_node_idx:
|
||||
if act[1]["node_name"] == self.config.starting_node_name:
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
@@ -74,5 +83,7 @@ class TAP001(AbstractScriptedAgent):
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
num_nodes = len(self.config.action_manager.node_names)
|
||||
# TODO: Change this to something?
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.logger.debug(f"Selected Starting node ID: {self.starting_node_idx}")
|
||||
|
||||
@@ -547,6 +547,26 @@ class PrimaiteGame:
|
||||
reward_function = RewardFunction.from_config(reward_function_cfg)
|
||||
|
||||
# CREATE AGENT
|
||||
|
||||
# TODO: MAKE THIS BIT WORK AND NOT THE IF/ELSE CHAIN OF HORRORS
|
||||
|
||||
# Pass through:
|
||||
# config
|
||||
# action manager
|
||||
# observation_manager
|
||||
# reward_function
|
||||
|
||||
new_agent_cfg = {
|
||||
"action_manager": action_space,
|
||||
"agent_name": agent_cfg["ref"],
|
||||
"observation_manager": obs_space,
|
||||
"agent_settings": agent_cfg.get("agent_settings", {}),
|
||||
"reward_function": reward_function,
|
||||
}
|
||||
new_agent_cfg = agent_cfg["settings"]
|
||||
# new_agent_cfg.update{}
|
||||
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=new_agent_cfg)
|
||||
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
settings = agent_cfg.get("agent_settings", {})
|
||||
|
||||
@@ -264,7 +264,7 @@ def example_network() -> Network:
|
||||
return network
|
||||
|
||||
|
||||
class ControlledAgent(AbstractAgent):
|
||||
class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
|
||||
"""Agent that can be controlled by the tests."""
|
||||
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user