#2869 - New Year, new changes. Actioning review comments and some changes following self-review and catchup

This commit is contained in:
Charlie Crane
2025-01-02 17:41:24 +00:00
parent dc6f2be209
commit 2108b914e3
28 changed files with 130 additions and 144 deletions

View File

@@ -1,10 +1,10 @@
repos:
- repo: local
hooks:
- id: ensure-copyright-clause
name: ensure copyright clause
entry: python copyright_clause_pre_commit_hook.py
language: python
# - repo: local
# hooks:
# - id: ensure-copyright-clause
# name: ensure copyright clause
# entry: python copyright_clause_pre_commit_hook.py
# language: python
- repo: http://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:

View File

@@ -13,7 +13,7 @@ Agents defined within PrimAITE have been updated to allow for easier creation of
Developing Agents for PrimAITE
==============================
Agents within PrimAITE, follow the shown inheritance structure, and
Agents within PrimAITE follow the inheritance structure shown below.
# TODO: Turn this into an inheritance diagram
@@ -32,7 +32,6 @@ AbstractAgent
| |
| | - RandomAgent
|
|
| - ProxyAgent
|
| - ControlledAgent
@@ -41,6 +40,8 @@ AbstractAgent
#. **ConfigSchema**:
Each new agent within PrimAITE should contain a ``ConfigSchema`` which holds all of the agent's configurable variables. This should not include parameters related to its *state*.
Agent generation will fail if incorrect parameters are passed to the ``ConfigSchema`` of the chosen agent.
.. code-block:: python
@@ -49,7 +50,7 @@ AbstractAgent
config: "ExampleAgent.ConfigSchema"
"""Agent configuration"""
num_executions: int
num_executions: int = 0
"""Number of action executions by agent"""
class ConfigSchema(AbstractAgent.ConfigSchema):
@@ -60,9 +61,43 @@ AbstractAgent
action_interval: int
"""Number of steps between agent actions"""
.. code-block:: YAML
- ref: example_green_agent
team: GREEN
type: ExampleAgent
observation_space: null
action_space:
action_list:
- type: do_nothing
action_map:
0:
action: do_nothing
options: {}
options:
nodes:
- node_name: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 25
frequency: 20
variance: 5
#. **identifier**:
All agent classes should have a ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry.
All agent classes should have an ``identifier`` attribute, a unique snake_case string, for when they are added to the base ``AbstractAgent`` registry. This identifier is then specified in your configuration YAML (as the agent's ``type``) and used by PrimAITE to generate the correct agent.
Changes to YAML file
====================

View File

@@ -17,7 +17,7 @@ from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
__all__ = ("AgentHistoryItem", "AgentStartSettings", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
__all__ = ("AgentHistoryItem", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
class AgentHistoryItem(BaseModel):
@@ -43,63 +43,18 @@ class AgentHistoryItem(BaseModel):
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AgentStartSettings":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions."
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
"""Construct agent settings from a config dictionary.
:param config: A dict of options for the agent settings.
:type config: Dict
:return: The agent settings.
:rtype: AgentSettings
"""
if config is None:
return cls()
return cls(**config)
class AbstractAgent(BaseModel):
"""Base class for scripted and RL agents."""
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
_logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
config: "AbstractAgent.ConfigSchema"
history: List[AgentHistoryItem] = []
action_manager: ActionManager
observation_manager: ObservationManager
reward_function: RewardFunction
class ConfigSchema(BaseModel):
"""
@@ -118,13 +73,34 @@ class AbstractAgent(BaseModel):
"""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
agent_name: ClassVar[str] = "Abstract_Agent" # TODO: Make this a ClassVar[str] like verb in actions?
history: List[AgentHistoryItem] = []
_logger: AgentLog = AgentLog(agent_name=agent_name)
action_manager: ActionManager
observation_manager: ObservationManager
reward_function: RewardFunction
agent_settings: Optional[AgentSettings] = None
agent_name: str = "Abstract_Agent"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
if identifier in cls._registry:
@@ -132,35 +108,11 @@ class AbstractAgent(BaseModel):
cls._registry[identifier] = cls
super().__init_subclass__(**kwargs)
@property
def logger(self) -> AgentLog:
"""Return the AgentLog."""
return self.config._logger
@property
def flatten_obs(self) -> bool:
"""Return agent flatten_obs param."""
return self.config.agent_settings.flatten_obs
@property
def history(self) -> List[AgentHistoryItem]:
"""Return the agent history."""
return self.config.history
@property
def observation_manager(self) -> ObservationManager:
"""Returns the agents observation manager."""
return self.config.observation_manager
@property
def action_manager(self) -> ActionManager:
"""Returns the agents action manager."""
return self.config.action_manager
@property
def reward_function(self) -> RewardFunction:
"""Returns the agents reward function."""
return self.config.reward_function
return self.config.flatten_obs
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
@@ -217,7 +169,7 @@ class AbstractAgent(BaseModel):
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.config.history.append(
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
@@ -225,7 +177,7 @@ class AbstractAgent(BaseModel):
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.config.history[-1].reward = self.reward_function.current_reward
self.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):

View File

@@ -36,7 +36,7 @@ from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]

View File

@@ -1,9 +1,9 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent import interface
from primaite.game.agent.scripted_agents import (
abstract_tap,
data_manipulation_bot,
interface,
probabilistic_agent,
random_agent,
)

View File

@@ -7,7 +7,7 @@ from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = "AbstractTAPAgent"
@@ -50,4 +50,4 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
num_nodes = len(self.config.action_manager.node_names)
starting_node_idx = random.randint(0, num_nodes - 1)
self.starting_node_name = self.config.action_manager.node_names[starting_node_idx]
self.logger.debug(f"Selected Starting node ID: {self.starting_node_name}")
self.logger.debug(f"Selected starting node: {self.starting_node_name}")

View File

@@ -6,7 +6,7 @@ import numpy as np
import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings
from primaite.game.agent.interface import AbstractScriptedAgent, AgentSettings
__all__ = "ProbabilisticAgent"

View File

@@ -4,7 +4,7 @@ from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = ("RandomAgent", "PeriodicAgent")
@@ -37,23 +37,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Periodic Agent."""
agent_name: str = "Periodic_Agent"
"""Name of the agent."""
# TODO: This is available in config.agent_settings.start_settings.start_step
start_step: int = 20
"The timestep at which an agent begins performing it's actions."
start_variance: int = 5
"Deviation around the start step."
# TODO: This is available in config.agent_settings.start_settings.frequency
frequency: int = 5
"The number of timesteps to wait between performing actions."
# TODO: This is available in config.agent_settings.start_settings.variance
variance: int = 0
"The amount the frequency can randomly change to."
max_executions: int = 999999
"Maximum number of times the agent can execute its action."
num_executions: int = 0
@@ -62,6 +48,22 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""
@property
def start_step(self) -> int:
"""Return the timestep at which an agent begins performing it's actions."""
return self.config.agent_settings.start_settings.start_step
@property
def start_variance(self) -> int:
"""Returns the deviation around the start step."""
return self.config.agent_settings.start_settings.variance
@property
def frequency(self) -> int:
"""Returns the number of timesteps to wait between performing actions."""
return self.config.agent_settings.start_settings.frequency
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -75,9 +77,9 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.config.max_executions:
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(timestep + self.frequency, self.variance)
self._set_next_execution_timestep(timestep + self.frequency, self.start_variance)
self.target_node = self.action_manager.node_names[0]
return "node_application_execute", {"node_name": self.target_node, "application_name": 0}

View File

@@ -10,7 +10,7 @@ from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
@@ -555,9 +555,6 @@ class PrimaiteGame:
# new_agent_cfg.update{}
if agent_type in AbstractAgent._registry:
print(agent_type)
print(agent_config)
print(AbstractAgent._registry)
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents
if agent_type == "ProxyAgent":

View File

@@ -10,7 +10,7 @@ import numpy as np
from gymnasium.core import ActType, ObsType
from primaite import getLogger
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO

View File

@@ -7,7 +7,7 @@ from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler

View File

@@ -9,7 +9,7 @@ from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.interface import AbstractAgent
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system import FileSystem

View File

@@ -7,7 +7,7 @@ import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import UserManager

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -7,7 +7,7 @@ import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -5,7 +5,7 @@ import pytest
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv

View File

@@ -17,7 +17,7 @@ from typing import Tuple
import pytest
import yaml
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus

View File

@@ -3,7 +3,7 @@ import pytest
import yaml
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv

View File

@@ -5,7 +5,7 @@ from typing import Tuple
import pytest
import yaml
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.container import Network

View File

@@ -5,7 +5,7 @@ from primaite.game.agent.rewards import (
WebpageUnavailablePenalty,
WebServer404Penalty,
)
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.interface.request import RequestResponse

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer