#2869 - Updates to ConfigSchema declaration and addressing some review comments

This commit is contained in:
Charlie Crane
2025-01-10 14:09:15 +00:00
parent 7af9d3724f
commit e3f4775acb
9 changed files with 44 additions and 126 deletions

View File

@@ -1,6 +1,6 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _about:
@@ -15,7 +15,12 @@ Developing Agents for PrimAITE
Agents within PrimAITE, follow the shown inheritance structure below.
The inheritance structure of agents within PrimAITE are shown below. When developing custom agents for use with PrimAITE, please see the relevant documentation for each agent type to determine which is most relevant for your implementation.
All agent types within PrimAITE are listed under the ``_registry`` attribute of the parent class, ``AbstractAgent``.
# TODO: Turn this into an inheritance diagram
# TODO: Would this be necessary?
AbstractAgent
|
@@ -40,12 +45,12 @@ AbstractAgent
#. **ConfigSchema**:
Configurable items within a new agent within PrimAITE should contain a ``ConfigSchema`` which holds all configurable variables of the agent. This should not include parameters related to its *state*.
Agent generation will fail if incorrect parameters are passed to the ConfigSchema, for the chosen Agent.
Agent generation will fail if incorrect or invalid parameters are passed to the ConfigSchema, of the chosen Agent.
.. code-block:: python
class ExampleAgent(AbstractAgent, identifier = "example_agent"):
class ExampleAgent(AbstractAgent, identifier = "ExampleAgent"):
"""An example agent for demonstration purposes."""
config: "ExampleAgent.ConfigSchema"
@@ -56,10 +61,10 @@ AbstractAgent
class ConfigSchema(AbstractAgent.ConfigSchema):
"""ExampleAgent configuration schema"""
agent_name: str
agent_name: str = "ExampleAgent
"""Name of agent"""
action_interval: int
"""Number of steps between agent actions"""
starting_host: int
"""Host node that this agent should start from in the given environment."""
.. code-block:: YAML
@@ -89,22 +94,24 @@ AbstractAgent
- type: DUMMY
agent_settings:
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
agent_name: "Example Agent"
starting_host: "Server_1"
#. **identifier**:
#. **Identifiers**:
All agent classes should have a ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent.
All agent classes should have a ``identifier`` attribute, a unique kebab-case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent.
Changes to YAML file
====================
Agent configurations specified within YAML files used for earlier versions of PrimAITE will need updating to be compatible with PrimAITE v4.0.0+.
Agents now follow a more standardised settings definition, so should be more consistent across YAML.
PrimAITE v4.0.0 introduces some breaking changes to how environment configuration yaml files are created. YAML files created for Primaite versions 3.3.0 should be compatible through a translation function, though it is encouraged that these are updated to reflect the updated format of 4.0.0+.
Agents now follow a more standardised settings definition, so should be more consistent across YAML files and the available agent types with PrimAITE.
# TODO: Show changes to YAML config needed here
All configurable items for agents sit under the ``agent_settings`` heading within your YAML files. There is no need for the inclusion of a ``start_settings``.

View File

@@ -115,7 +115,6 @@ class AbstractAgent(BaseModel):
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
print(config)
obj = cls(
config=cls.ConfigSchema(**config["agent_settings"]),
action_manager=ActionManager.from_config(config["game"], config["action_manager"]),

View File

@@ -6,16 +6,17 @@ from abc import abstractmethod
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from pydantic import Field
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = "AbstractTAPAgent"
class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
class AbstractTAPAgent(AbstractScriptedAgent, identifier="AbstractTAP"):
"""Base class for TAP agents to inherit from."""
config: "AbstractTAPAgent.ConfigSchema"
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
agent_name: str = "Abstract_TAP"
next_execution_timestep: int = 0
@@ -45,7 +46,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.config.action_manager.node_names)
num_nodes = len(self.action_manager.node_names)
starting_node_idx = random.randint(0, num_nodes - 1)
self.starting_node_name = self.config.action_manager.node_names[starting_node_idx]
self.logger.debug(f"Selected starting node: {self.starting_node_name}")
self.config.starting_node_name = self.action_manager.node_names[starting_node_idx]
self.logger.debug(f"Selected starting node: {self.config.starting_node_name}")

View File

@@ -2,6 +2,7 @@
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from pydantic import Field
from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent
@@ -11,7 +12,7 @@ __all__ = "DataManipulationAgent"
class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingAgent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
config: "DataManipulationAgent.ConfigSchema"
config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema())
agent_name: str = "Data_Manipulation_Agent"
class ConfigSchema(AbstractTAPAgent.ConfigSchema):

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Tuple
import numpy as np
import pydantic
from gymnasium.core import ObsType
from pydantic import Field
from primaite.game.agent.interface import AbstractScriptedAgent
@@ -14,7 +15,7 @@ __all__ = "ProbabilisticAgent"
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
config: "ProbabilisticAgent.ConfigSchema"
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
rng: Any = np.random.default_rng(np.random.randint(0, 65535))
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
@@ -22,7 +23,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
agent_name: str = "ProbabilisticAgent"
action_probabilities: Dict[int, float]
action_probabilities: Dict[int, float] = None
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
@pydantic.field_validator("action_probabilities", mode="after")

View File

@@ -3,15 +3,18 @@ import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from pydantic import Field
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = ("RandomAgent", "PeriodicAgent")
class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"):
"""Agent that ignores its observation and acts completely at random."""
config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema())
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Random Agents."""
@@ -30,10 +33,10 @@ class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
return self.action_manager.get_action(self.action_manager.space.sample())
class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
config: "PeriodicAgent.ConfigSchema" = {}
config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema())
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Periodic Agent."""
@@ -45,25 +48,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
"Maximum number of times the agent can execute its action."
num_executions: int = 0
"""Number of times the agent has executed an action."""
# TODO: Also in abstract_tap - move up and inherit? Add to AgentStartSettings?
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""
@property
def start_step(self) -> int:
"""Return the timestep at which an agent begins performing it's actions."""
return self.config.agent_settings.start_settings.start_step
@property
def start_variance(self) -> int:
"""Returns the deviation around the start step."""
return self.config.agent_settings.start_settings.variance
@property
def frequency(self) -> int:
"""Returns the number of timesteps to wait between performing actions."""
return self.config.agent_settings.start_settings.frequency
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -79,8 +66,8 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(timestep + self.frequency, self.start_variance)
self._set_next_execution_timestep(timestep + self.config.frequency, self.config.variance)
self.target_node = self.action_manager.node_names[0]
return "node_application_execute", {"node_name": self.target_node, "application_name": 0}
return "DONOTHING", {}
return "do_nothing", {}

View File

@@ -1,78 +0,0 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class TAP001(AbstractScriptedAgent):
"""
TAP001 | Mobile Malware -- Ransomware Variant.
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
next_execution_timestep: int = 0
starting_node_idx: int = 0
installed: bool = False
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute the ransomware application.
This application acts a wrapper around the kill-chain, similar to green-analyst and
the previous UC2 data manipulation bot.
:param obs: Current observation for this agent.
:type obs: ObsType
:param timestep: The current simulation timestep, used for scheduling actions
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
if not self.installed:
self.installed = True
return "NODE_APPLICATION_INSTALL", {
"node_id": self.starting_node_idx,
"application_name": "RansomwareScript",
}
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
for n, act in self.action_manager.action_map.items():
if not act[0] == "NODE_APPLICATION_INSTALL":
continue
if act[1]["node_id"] == self.starting_node_idx:
self.ip_address = act[1]["ip_address"]
return
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)

View File

@@ -444,7 +444,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},

View File

@@ -63,8 +63,8 @@ agents:
frequency: 4
variance: 3
action_probabilities:
0: 0.6
1: 0.4
0: 0.4
1: 0.6