#2869 - New Year, new changes. Actioning review comments and some changes following self-review and catchup

This commit is contained in:
Charlie Crane
2025-01-02 17:41:24 +00:00
parent dc6f2be209
commit 2108b914e3
28 changed files with 130 additions and 144 deletions

View File

@@ -1,10 +1,10 @@
repos:
- repo: local
hooks:
- id: ensure-copyright-clause
name: ensure copyright clause
entry: python copyright_clause_pre_commit_hook.py
language: python
# - repo: local
# hooks:
# - id: ensure-copyright-clause
# name: ensure copyright clause
# entry: python copyright_clause_pre_commit_hook.py
# language: python
- repo: http://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:

View File

@@ -7,13 +7,13 @@
Extensible Agents
*****************
Agents defined within PrimAITE have been updated to allow for easier creation of new bespoke agents.
Agents defined within PrimAITE have been updated to allow for easier creation of new bespoke agents.
Developing Agents for PrimAITE
==============================
Agents within PrimAITE, follow the shown inheritance structure, and
Agents within PrimAITE follow the inheritance structure shown below.
# TODO: Turn this into an inheritance diagram
@@ -32,7 +32,6 @@ AbstractAgent
| |
| | - RandomAgent
|
|
| - ProxyAgent
|
| - ControlledAgent
@@ -41,6 +40,8 @@ AbstractAgent
#. **ConfigSchema**:
Each new agent within PrimAITE should define a ``ConfigSchema`` which holds all of the agent's configurable variables. This should not include parameters related to its *state*.
Agent generation will fail if incorrect parameters are passed to the ``ConfigSchema`` of the chosen agent.
.. code-block:: python
@@ -49,7 +50,7 @@ AbstractAgent
config: "ExampleAgent.ConfigSchema"
"""Agent configuration"""
num_executions: int
num_executions: int = 0
"""Number of action executions by agent"""
class ConfigSchema(AbstractAgent.ConfigSchema):
@@ -60,9 +61,43 @@ AbstractAgent
action_interval: int
"""Number of steps between agent actions"""
.. code-block:: YAML
- ref: example_green_agent
team: GREEN
type: ExampleAgent
observation_space: null
action_space:
action_list:
- type: do_nothing
action_map:
0:
action: do_nothing
options: {}
options:
nodes:
- node_name: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 25
frequency: 20
variance: 5
#. **identifier**:
All agent classes should have a ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry.
All agent classes should have an ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent.
Changes to YAML file
====================

View File

@@ -17,7 +17,7 @@ from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
__all__ = ("AgentHistoryItem", "AgentStartSettings", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
__all__ = ("AgentHistoryItem", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
class AgentHistoryItem(BaseModel):
@@ -43,63 +43,18 @@ class AgentHistoryItem(BaseModel):
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AgentStartSettings":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions."
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
"""Construct agent settings from a config dictionary.
:param config: A dict of options for the agent settings.
:type config: Dict
:return: The agent settings.
:rtype: AgentSettings
"""
if config is None:
return cls()
return cls(**config)
class AbstractAgent(BaseModel):
"""Base class for scripted and RL agents."""
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
_logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
config: "AbstractAgent.ConfigSchema"
history: List[AgentHistoryItem] = []
action_manager: ActionManager
observation_manager: ObservationManager
reward_function: RewardFunction
class ConfigSchema(BaseModel):
"""
@@ -118,13 +73,34 @@ class AbstractAgent(BaseModel):
"""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
agent_name: ClassVar[str] = "Abstract_Agent" # TODO: Make this a ClassVar[str] like verb in actions?
history: List[AgentHistoryItem] = []
_logger: AgentLog = AgentLog(agent_name=agent_name)
action_manager: ActionManager
observation_manager: ObservationManager
reward_function: RewardFunction
agent_settings: Optional[AgentSettings] = None
agent_name: str = "Abstract_Agent"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
if identifier in cls._registry:
@@ -132,35 +108,11 @@ class AbstractAgent(BaseModel):
cls._registry[identifier] = cls
super().__init_subclass__(**kwargs)
@property
def logger(self) -> AgentLog:
"""Return the AgentLog."""
return self.config._logger
@property
def flatten_obs(self) -> bool:
"""Return agent flatten_obs param."""
return self.config.agent_settings.flatten_obs
@property
def history(self) -> List[AgentHistoryItem]:
"""Return the agent history."""
return self.config.history
@property
def observation_manager(self) -> ObservationManager:
"""Returns the agents observation manager."""
return self.config.observation_manager
@property
def action_manager(self) -> ActionManager:
"""Returns the agents action manager."""
return self.config.action_manager
@property
def reward_function(self) -> RewardFunction:
"""Returns the agents reward function."""
return self.config.reward_function
return self.config.flatten_obs
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
@@ -217,7 +169,7 @@ class AbstractAgent(BaseModel):
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.config.history.append(
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
@@ -225,7 +177,7 @@ class AbstractAgent(BaseModel):
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.config.history[-1].reward = self.reward_function.current_reward
self.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):

View File

@@ -36,7 +36,7 @@ from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]

View File

@@ -1,9 +1,9 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent import interface
from primaite.game.agent.scripted_agents import (
abstract_tap,
data_manipulation_bot,
interface,
probabilistic_agent,
random_agent,
)

View File

@@ -7,7 +7,7 @@ from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
from primaite.game.agent.interface import AbstractScriptedAgent
# __all__ must be a sequence of names. A bare string is iterated character by
# character by `from module import *`, which then fails looking up "A", "b", ...
__all__ = ("AbstractTAPAgent",)
@@ -50,4 +50,4 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
num_nodes = len(self.config.action_manager.node_names)
starting_node_idx = random.randint(0, num_nodes - 1)
self.starting_node_name = self.config.action_manager.node_names[starting_node_idx]
self.logger.debug(f"Selected Starting node ID: {self.starting_node_name}")
self.logger.debug(f"Selected starting node: {self.starting_node_name}")

View File

@@ -6,7 +6,7 @@ import numpy as np
import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings
from primaite.game.agent.interface import AbstractScriptedAgent, AgentSettings
# __all__ must be a sequence of names. A bare string is iterated character by
# character by `from module import *`, which then fails looking up "P", "r", ...
__all__ = ("ProbabilisticAgent",)

View File

@@ -4,7 +4,7 @@ from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = ("RandomAgent", "PeriodicAgent")
@@ -37,23 +37,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Periodic Agent."""
agent_name: str = "Periodic_Agent"
"""Name of the agent."""
agent_name: str = "Periodic_Agent"
"""Name of the agent."""
# TODO: This is available in config.agent_settings.start_settings.start_step
start_step: int = 20
"The timestep at which an agent begins performing it's actions."
start_variance: int = 5
"Deviation around the start step."
# TODO: This is available in config.agent_settings.start_settings.frequency
frequency: int = 5
"The number of timesteps to wait between performing actions."
# TODO: This is available in config.agent_settings.start_settings.variance
variance: int = 0
"The amount the frequency can randomly change to."
max_executions: int = 999999
"Maximum number of times the agent can execute its action."
num_executions: int = 0
@@ -62,6 +48,22 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""
@property
def start_step(self) -> int:
"""Return the timestep at which an agent begins performing it's actions."""
return self.config.agent_settings.start_settings.start_step
@property
def start_variance(self) -> int:
"""Returns the deviation around the start step."""
return self.config.agent_settings.start_settings.variance
@property
def frequency(self) -> int:
"""Returns the number of timesteps to wait between performing actions."""
return self.config.agent_settings.start_settings.frequency
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -75,9 +77,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.config.max_executions:
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(timestep + self.frequency, self.variance)
self._set_next_execution_timestep(timestep + self.frequency, self.start_variance)
self.target_node = self.action_manager.node_names[0]
return "node_application_execute", {"node_name": self.target_node, "application_name": 0}

View File

@@ -10,7 +10,7 @@ from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
@@ -555,9 +555,6 @@ class PrimaiteGame:
# new_agent_cfg.update{}
if agent_type in AbstractAgent._registry:
print(agent_type)
print(agent_config)
print(AbstractAgent._registry)
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents
if agent_type == "ProxyAgent":

View File

@@ -10,7 +10,7 @@ import numpy as np
from gymnasium.core import ActType, ObsType
from primaite import getLogger
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO

View File

@@ -7,7 +7,7 @@ from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler

View File

@@ -9,7 +9,7 @@ from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.interface import AbstractAgent
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system import FileSystem

View File

@@ -7,7 +7,7 @@ import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import UserManager

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -7,7 +7,7 @@ import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -5,7 +5,7 @@ import pytest
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv

View File

@@ -17,7 +17,7 @@ from typing import Tuple
import pytest
import yaml
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus

View File

@@ -3,7 +3,7 @@ import pytest
import yaml
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv

View File

@@ -5,7 +5,7 @@ from typing import Tuple
import pytest
import yaml
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.container import Network

View File

@@ -5,7 +5,7 @@ from primaite.game.agent.rewards import (
WebpageUnavailablePenalty,
WebServer404Penalty,
)
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.interface.request import RequestResponse

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer