#2869 - Updates to Probabilistic Agent to follow the defined extensibility schema.
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import logging
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
|
||||
"""Base class for TAP agents to inherit from."""
|
||||
|
||||
config: "AbstractTAPAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Abstract TAP agents."""
|
||||
|
||||
agent_name: str = "Abstract_TAP"
|
||||
starting_node_name: str
|
||||
next_execution_timestep: int
|
||||
|
||||
@abstractmethod
|
||||
def setup_agent(self) -> None:
|
||||
"""Set up agent."""
|
||||
pass
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
|
||||
)
|
||||
self.config.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.config.action_manager.node_names)
|
||||
starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.starting_node_name = self.config.action_manager.node_names[starting_node_idx]
|
||||
self.config.logger.debug(f"Selected Starting node ID: {self.starting_node_name}")
|
||||
@@ -1,5 +1,4 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
"""Agents with predefined behaviours."""
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
@@ -8,23 +8,20 @@ from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"):
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
class ConfigSchema(pydantic.BaseModel):
|
||||
"""Config schema for Probabilistic agent settings."""
|
||||
config: "ProbabilisticAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Probabilistic Agent."""
|
||||
|
||||
agent_name: str = "Probabilistic_Agent"
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
"""Strict validation."""
|
||||
action_space: ActionManager
|
||||
action_probabilities: Dict[int, float]
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
@@ -45,32 +42,16 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"
|
||||
)
|
||||
return v
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
settings: Dict = {},
|
||||
) -> None:
|
||||
# If the action probabilities are not specified, create equal probabilities for all actions
|
||||
if "action_probabilities" not in settings:
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.ConfigSchema(**settings)
|
||||
def __init__(self) -> None:
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
|
||||
|
||||
@property
|
||||
def probabilities(self) -> Dict[str, int]:
|
||||
"""Convenience method to view the probabilities of the Agent."""
|
||||
return np.asarray(list(self.config.action_probabilities.values()))
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Sample the action space randomly.
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent
|
||||
|
||||
|
||||
class TAP001(AbstractTAPAgent, identifier="TAP001"):
|
||||
"""
|
||||
TAP001 | Mobile Malware -- Ransomware Variant.
|
||||
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
config: "TAP001.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractTAPAgent.ConfigSchema):
|
||||
"""Configuration Schema for TAP001 Agent."""
|
||||
|
||||
agent_name: str = "TAP001"
|
||||
installed: bool = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""___init___ bruv. Restecpa."""
|
||||
super().__init__()
|
||||
self.setup_agent()
|
||||
|
||||
@property
|
||||
def starting_node_name(self) -> str:
|
||||
"""Node that TAP001 starts from."""
|
||||
return self.config.starting_node_name
|
||||
|
||||
def get_action(self, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
|
||||
This application acts a wrapper around the kill-chain, similar to green-analyst and
|
||||
the previous UC2 data manipulation bot.
|
||||
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.config.next_execution_timestep:
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
|
||||
|
||||
if not self.config.installed:
|
||||
self.config.installed = True
|
||||
return "node_application_install", {
|
||||
"node_name": self.starting_node_name,
|
||||
"application_name": "RansomwareScript",
|
||||
}
|
||||
|
||||
return "node_application_execute", {
|
||||
"node_name": self.starting_node_name,
|
||||
"application_name": "RansomwareScript",
|
||||
}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
|
||||
for n, act in self.config.action_manager.action_map.items():
|
||||
if not act[0] == "node_application_install":
|
||||
continue
|
||||
if act[1]["node_name"] == self.starting_node_name:
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
Reference in New Issue
Block a user