Merge remote-tracking branch 'origin/dev' into feature/1847-more-red-agents

This commit is contained in:
Czar Echavez
2024-06-11 14:26:11 +01:00
66 changed files with 8645 additions and 1012 deletions

View File

@@ -1 +1 @@
3.0.0b9
3.0.0

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
pass
class AgentActionHistoryItem(BaseModel):
class AgentHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
timestep: int
@@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel):
response: RequestResponse
"""The response sent back by the simulator for this action."""
reward: Optional[float] = None
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -110,7 +112,7 @@ class AbstractAgent(ABC):
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.action_history: List[AgentActionHistoryItem] = []
self.history: List[AgentHistoryItem] = []
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -130,7 +132,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
return self.reward_function.update(state=state, last_action_response=self.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -161,12 +163,16 @@ class AbstractAgent(ABC):
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.action_history.append(
AgentActionHistoryItem(
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
)
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""

View File

@@ -34,7 +34,7 @@ from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
@@ -44,7 +44,7 @@ class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -64,7 +64,7 @@ class AbstractReward:
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
file_name,
]
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward):
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
@@ -343,7 +343,7 @@ class SharedReward(AbstractReward):
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback(self.agent_name)
@@ -389,7 +389,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.

View File

@@ -160,6 +160,7 @@ class PrimaiteGame:
agent = self.agents[agent_name]
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.save_reward_to_history()
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
agent.reward_function.total_reward += agent.reward_function.current_reward
@@ -359,11 +360,6 @@ class PrimaiteGame:
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
ransomware_encrypt_p_of_success=float(
opt.get("ransomware_encrypt_p_of_success", "0.1")
),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:

View File

@@ -4,13 +4,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Customising red agents\n",
"# Customising UC2 Red Agents\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n",
"\n",
"First, let's load the standard Data Manipulation config file, and see what the red agent does.\n",
"\n",
"*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*"
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
]
},
{
@@ -22,7 +24,7 @@
"# Imports\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.game.agent.interface import AgentHistoryItem\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml\n",
"from pprint import pprint"
@@ -63,7 +65,7 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",

View File

@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Manipulation Scenario\n"
"# Data Manipulation Scenario\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
]
},
{
@@ -59,7 +61,7 @@
"\n",
"At the start of every episode, the red agent randomly chooses either client 1 or client 2 to login to. It waits a bit then sends a DELETE query to the database from its chosen client. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n",
"\n",
"[<img src=\"_package_data/uc2_attack.png\" width=\"500\"/>](_package_data/uc2_attack.png)\n",
"![uc2_attack](./_package_data/uc2_attack.png)\n",
"\n",
"_(click image to enlarge)_"
]
@@ -79,7 +81,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reinforcement learning details"
"## Reinforcement learning details"
]
},
{
@@ -180,15 +182,15 @@
"| link_id | endpoint_a | endpoint_b |\n",
"|---------|------------------|-------------------|\n",
"| 1 | router_1 | switch_1 |\n",
"| 1 | router_1 | switch_2 |\n",
"| 1 | switch_1 | domain_controller |\n",
"| 1 | switch_1 | web_server |\n",
"| 1 | switch_1 | database_server |\n",
"| 1 | switch_1 | backup_server |\n",
"| 1 | switch_1 | security_suite |\n",
"| 1 | switch_2 | client_1 |\n",
"| 1 | switch_2 | client_2 |\n",
"| 1 | switch_2 | security_suite |\n",
"| 2 | router_1 | switch_2 |\n",
"| 3 | switch_1 | domain_controller |\n",
"| 4 | switch_1 | web_server |\n",
"| 5 | switch_1 | database_server |\n",
"| 6 | switch_1 | backup_server |\n",
"| 7 | switch_1 | security_suite |\n",
"| 8 | switch_2 | client_1 |\n",
"| 9 | switch_2 | client_2 |\n",
"| 10 | switch_2 | security_suite |\n",
"\n",
"\n",
"The ACL rules in the observation space appear in the same order that they do in the actual ACL. Though, only the first 10 rules are shown, there are default rules lower down that cannot be changed by the agent. The extra rules just allow the network to function normally, by allowing pings, ARP traffic, etc.\n",
@@ -392,7 +394,7 @@
"# Imports\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.game.agent.interface import AgentHistoryItem\n",
"import yaml\n",
"from pprint import pprint\n"
]
@@ -401,7 +403,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Instantiate the environment. We also disable the agent observation flattening.\n",
"Instantiate the environment. \n",
"We will also disable the agent observation flattening.\n",
"\n",
"This cell will print the observation when the network is healthy. You should be able to verify Node file and service statuses against the description above."
]
@@ -444,7 +447,7 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
@@ -691,7 +694,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -705,7 +708,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.8"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting information out of PrimAITE\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
"import yaml\n",
"from primaite import PRIMAITE_CONFIG\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from notebook.services.config import ConfigManager\n",
"\n",
"cm = ConfigManager().update('notebook', {'limit_output': 50}) # limit output lines to 50 - for neatness\n",
"\n",
"# create the env\n",
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"env = PrimaiteGymEnv(env_config=cfg)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualising the Simulation Network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network can be visualised by running the code below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.game.simulation.network.draw()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting the state of a simulation object\n",
"\n",
"The state of the simulation object is used to determine the observation space used by agents.\n",
"\n",
"Any object created using the ``SimComponent`` class has a ``describe_state`` method which can show the state of the object.\n",
"\n",
"An example of such an object is ``Computer`` which inherits from ``SimComponent``. In the default network configuration, ``client_1`` is a Computer object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
"client_1.describe_state()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### More specific describe_state\n",
"\n",
"As you can see, the output from the ``describe_state`` method for the ``Computer`` object includes the describe state for all its components. This can cause a large describe state output.\n",
"\n",
"As stated, the ``describe_state`` can be called on any object that inherits ``SimComponent``. This can allow you retrieve the state of a specific item."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client_1.file_system.describe_state()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## System Logs\n",
"\n",
"Objects that inherit from the ``Node`` class will inherit the ``sys_log`` attribute.\n",
"\n",
"This is to simulate the idea that items such as Computer, Routers, Servers, etc. have a logging system used to diagnose problems."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# store config\n",
"# this is to prevent the notebook from breaking your local settings\n",
"was_enabled = PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"]\n",
"was_syslogs_enabled = PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"]\n",
"\n",
"# enable dev mode so that the default config outputs are overridden for this demo\n",
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = True\n",
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = True\n",
"\n",
"\n",
"\n",
"\n",
"# Remake the environment\n",
"env = PrimaiteGymEnv(env_config=cfg)\n",
"\n",
"# get the example computer\n",
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
"\n",
"# show sys logs on terminal\n",
"client_1.sys_log.show()\n",
"\n",
"\n",
"\n",
"\n",
"# restore config\n",
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n",
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -6,7 +6,9 @@
"source": [
"# Requests and Responses\n",
"\n",
"Agents interact with the PrimAITE simulation via the Request system.\n"
"Agents interact with the PrimAITE simulation via the Request system.\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
]
},
{

View File

@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Multi agent system using RLLIB\n",
"# Train a Multi agent system using RLLIB\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
]
@@ -25,13 +27,13 @@
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
@@ -60,8 +62,8 @@
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
]

View File

@@ -4,7 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Single agent system using RLLib\n",
"# Train a Single agent system using RLLib\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
]
},
@@ -18,8 +21,7 @@
"import yaml\n",
"from primaite.config.load import data_manipulation_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from ray.rllib.algorithms import ppo\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray import air, tune\n",
"import ray\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
@@ -52,8 +54,8 @@
"\n",
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")\n"
]
@@ -74,7 +76,7 @@
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128}\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
@@ -97,7 +99,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -6,6 +6,8 @@
"source": [
"# Training an SB3 Agent\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
]
},
@@ -43,7 +45,10 @@
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)"
" cfg = yaml.safe_load(f)\n",
"for agent in cfg['agents']:\n",
" if agent['ref'] == 'defender':\n",
" agent['agent_settings']['flatten_obs']=True"
]
},
{
@@ -177,7 +182,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.8"
}
},
"nbformat": 4,

View File

@@ -6,6 +6,8 @@
"source": [
"# Using Episode Schedules\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
"\n",
@@ -13,50 +15,6 @@
"directory with several config files that work together."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining variations in the config file.\n",
"\n",
"### Base scenario\n",
"The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are \n",
"populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that\n",
"remain fixed for the entire training/evaluation session.\n",
"\n",
"The placeholders are defined as YAML Aliases and they are denoted by an asterisk (`*placeholder`).\n",
"\n",
"### Variations\n",
"For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.\n",
"\n",
"The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand (`&anchor`).\n",
"\n",
"[Learn more about YAML Aliases and Anchors here.](https://www.educative.io/blog/advanced-yaml-syntax-cheatsheet#:~:text=YAML%20Anchors%20and%20Alias)\n",
"\n",
"### Schedule\n",
"Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a\n",
"YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.\n",
"\n",
"It takes the following format:\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```yaml\n",
"base_scenario: base.yaml\n",
"schedule:\n",
" 0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)\n",
" - laydown_1.yaml\n",
" - attack_1.yaml\n",
" 1: # list of variations to load in at episode 1 (after the first env.reset() call)\n",
" - laydown_2.yaml\n",
" - attack_2.yaml\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -298,8 +256,8 @@
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_A'].action_history[i].action\n",
" red_action = env.game.agents['red_A'].action_history[i].action\n",
" green_action = env.game.agents['green_A'].history[i].action\n",
" red_action = env.game.agents['red_A'].history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]
@@ -329,8 +287,8 @@
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_B'].action_history[i].action\n",
" red_action = env.game.agents['red_B'].action_history[i].action\n",
" green_action = env.game.agents['green_B'].history[i].action\n",
" red_action = env.game.agents['red_B'].history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]

View File

@@ -4,8 +4,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple multi-processing demo using SubprocVecEnv from SB3\n",
"Based on a code example provided by Rachael Proctor."
"# Simple multi-processing demonstration\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook uses SubprocVecEnv from SB3."
]
},
{
@@ -140,7 +143,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite import getLogger
from primaite.game.agent.interface import ProxyAgent
@@ -12,6 +11,7 @@ from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
@@ -37,6 +37,8 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
@property
def agent(self) -> ProxyAgent:
@@ -61,7 +63,7 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
"agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(step, action, state, reward)
@@ -83,16 +85,19 @@ class PrimaiteGymEnv(gymnasium.Env):
with open(path, "w") as file:
json.dump(data, file)
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
_LOGGER.info(
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
PacketCapture.clear()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
@@ -126,166 +131,5 @@ class PrimaiteGymEnv(gymnasium.Env):
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
def close(self):
"""Close the simulation."""
self.env.close()
@property
def game(self) -> PrimaiteGame:
"""Pass through game from env."""
return self.env.game
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict
"""
self.episode_counter: int = 0
"""Current episode number."""
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
def step(
self, actions: Dict[str, ActType]
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
:type actions: Dict[str, ActType]
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
step = self.game.step_counter
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.pre_timestep()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
# 3. Get next observations
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(step, actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": step,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
}
with open(path, "w") as file:
json.dump(data, file)
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)

View File

@@ -87,7 +87,7 @@ class PrimaiteIO:
"""Return the path where agent actions will be saved."""
return self.session_path / "agent_actions" / f"episode_{episode}.json"
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
"""Take the contents of the agent action log and write it to a file.
:param episode: Episode number

View File

@@ -0,0 +1,177 @@
import json
from typing import Dict, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.system.core.packet_capture import PacketCapture
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict
"""
self.episode_counter: int = 0
"""Current episode number."""
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
self._obs_space_in_preferred_format = True
self._action_space_in_preferred_format = True
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
PacketCapture.clear()
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
def step(
self, actions: Dict[str, ActType]
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
:type actions: Dict[str, ActType]
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
step = self.game.step_counter
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.pre_timestep()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
# 3. Get next observations
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(step, actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": step,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
}
with open(path, "w") as file:
json.dump(data, file)
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
def close(self):
"""Close the simulation."""
self.env.close()
@property
def game(self) -> PrimaiteGame:
"""Pass through game from env."""
return self.env.game

View File

@@ -6,7 +6,9 @@
"source": [
"# Build a simulation using the Python API\n",
"\n",
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects."
]
},
{

View File

@@ -7,6 +7,8 @@
"source": [
"# PrimAITE Router Simulation Demo\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This demo uses a modified version of the ARCD Use Case 2 Network (seen below) to demonstrate the capabilities of the Network simulator in PrimAITE."
]
},

View File

@@ -1,8 +1,6 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -11,43 +9,10 @@ from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
class RansomwareAttackStage(IntEnum):
"""
Enumeration representing different attack stages of the ransomware script.
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
in the simulation.
Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
DOWNLOAD = 1
"Installing the Encryption Script - Testing"
INSTALL = 2
"The stage where logon procedures are simulated."
ACTIVATE = 3
"Operating Status Changes"
PROPAGATE = 4
"Represents the stage of performing a horizontal port scan on the target."
COMMAND_AND_CONTROL = 5
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
PAYLOAD = 6
"Stage of actively attacking the target."
SUCCEEDED = 7
"Indicates the attack has been successfully completed."
FAILED = 8
"Signifies that the attack has failed."
class RansomwareScript(Application):
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
:ivar payload: The attack stage query payload. (Default Corrupt)
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:ivar repeat: Whether to repeat attacking once finished.
:ivar payload: The attack stage query payload. (Default ENCRYPT)
"""
server_ip_address: Optional[IPv4Address] = None
@@ -56,16 +21,6 @@ class RansomwareScript(Application):
"""Password required to access the database."""
payload: Optional[str] = "ENCRYPT"
"Payload String for the payload stage"
target_scan_p_of_success: float = 0.9
"Probability of the target scan succeeding: Default 0.9"
c2_beacon_p_of_success: float = 0.9
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
ransomware_encrypt_p_of_success: float = 0.9
"Probability of the ransomware attack succeeding: Default 0.9"
repeat: bool = False
"If true, the Denial of Service bot will keep performing the attack."
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
"The ransomware attack stage. See RansomwareAttackStage Class"
def __init__(self, **kwargs):
kwargs["name"] = "RansomwareScript"
@@ -90,7 +45,7 @@ class RansomwareScript(Application):
@property
def _host_db_client(self) -> DatabaseClient:
"""Return the database client that is installed on the same machine as the Ransomware Script."""
db_client = self.software_manager.software.get("DatabaseClient")
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
if db_client is None:
self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.")
return db_client
@@ -108,16 +63,6 @@ class RansomwareScript(Application):
)
return rm
def _activate(self):
"""
Simulate the install process as the initial stage of the attack.
Advances the attack stage to 'ACTIVATE' attack state.
"""
if self.attack_stage == RansomwareAttackStage.INSTALL:
self.sys_log.info(f"{self.name}: Activated!")
self.attack_stage = RansomwareAttackStage.ACTIVATE
def run(self) -> bool:
"""Calls the parent classes execute method before starting the application loop."""
super().run()
@@ -133,20 +78,9 @@ class RansomwareScript(Application):
return False
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self.attack_stage = RansomwareAttackStage.NOT_STARTED
self._local_download()
self._install()
self._activate()
self._perform_target_scan()
self._setup_beacon()
self._perform_ransomware_encrypt()
if self.repeat and self.attack_stage in (
RansomwareAttackStage.SUCCEEDED,
RansomwareAttackStage.FAILED,
):
self.attack_stage = RansomwareAttackStage.NOT_STARTED
return True
if self._perform_ransomware_encrypt():
return True
return False
else:
self.sys_log.warning(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
return False
@@ -156,10 +90,6 @@ class RansomwareScript(Application):
server_ip_address: IPv4Address,
server_password: Optional[str] = None,
payload: Optional[str] = None,
target_scan_p_of_success: Optional[float] = None,
c2_beacon_p_of_success: Optional[float] = None,
ransomware_encrypt_p_of_success: Optional[float] = None,
repeat: bool = True,
):
"""
Configure the Ransomware Script to communicate with a DatabaseService.
@@ -167,10 +97,6 @@ class RansomwareScript(Application):
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The attack stage query (Encrypt / Delete)
:param target_scan_p_of_success: The probability of success for the target scan stage.
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:param repeat: Whether to repeat attacking once finished.
"""
if server_ip_address:
self.server_ip_address = server_ip_address
@@ -178,74 +104,15 @@ class RansomwareScript(Application):
self.server_password = server_password
if payload:
self.payload = payload
if target_scan_p_of_success:
self.target_scan_p_of_success = target_scan_p_of_success
if c2_beacon_p_of_success:
self.c2_beacon_p_of_success = c2_beacon_p_of_success
if ransomware_encrypt_p_of_success:
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
if repeat:
self.repeat = repeat
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
f"{repeat=}."
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
)
def _install(self):
"""
Simulate the install stage in the kill-chain.
Advances the attack stage to 'ACTIVATE' if successful.
From this attack stage onwards.
the ransomware application is now visible from this point onwardin the observation space.
"""
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
downloads_folder = self.file_system.get_folder(folder_name="downloads")
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
ransomware_file.num_access += 1
self.attack_stage = RansomwareAttackStage.INSTALL
def _setup_beacon(self):
"""
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
by default the success rate is 0.5
"""
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
if simulate_trial(self.c2_beacon_p_of_success):
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
if c2c_setup:
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
def _perform_target_scan(self):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PROPAGATE` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
# perform a port scan to identify that the SQL port is open on the server
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
if simulate_trial(self.target_scan_p_of_success):
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
if port_is_open:
self.attack_stage = RansomwareAttackStage.PROPAGATE
def attack(self) -> bool:
"""Perform the attack steps after opening the application."""
self.run()
if not self._can_perform_action():
self.sys_log.warning("Ransomware application is unable to perform it's actions.")
self.run()
self.num_executions += 1
return self._application_loop()
@@ -254,57 +121,30 @@ class RansomwareScript(Application):
self._db_connection = self._host_db_client.get_new_connection()
return True if self._db_connection else False
def _perform_ransomware_encrypt(self):
def _perform_ransomware_encrypt(self) -> bool:
"""
Execute the Ransomware Encrypt payload on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
"""
if self._host_db_client is None:
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
self.attack_stage = RansomwareAttackStage.FAILED
return
return False
self._host_db_client.server_ip_address = self.server_ip_address
self._host_db_client.server_password = self.server_password
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
if simulate_trial(self.ransomware_encrypt_p_of_success):
self.sys_log.info(f"{self.name}: Attempting to launch payload")
if not self._db_connection:
self._establish_db_connection()
if self._db_connection:
attack_successful = self._db_connection.query(self.payload)
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
if attack_successful:
self.sys_log.info(f"{self.name}: Payload Successful")
self.attack_stage = RansomwareAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Payload failed")
self.attack_stage = RansomwareAttackStage.FAILED
self.sys_log.info(f"{self.name}: Attempting to launch payload")
if not self._db_connection:
self._establish_db_connection()
if self._db_connection:
attack_successful = self._db_connection.query(self.payload)
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
if attack_successful:
self.sys_log.info(f"{self.name}: Payload Successful")
return True
else:
self.sys_log.info(f"{self.name}: Payload failed")
return False
else:
self.sys_log.warning("Attack Attempted to launch too quickly")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download(self):
"""Downloads itself via the onto the local file_system."""
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
if self._local_download_verify():
self.attack_stage = RansomwareAttackStage.DOWNLOAD
else:
self.sys_log.info("Malware failed to create a installation location")
self.attack_stage = RansomwareAttackStage.FAILED
else:
self.sys_log.info("Malware failed to download")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download_verify(self) -> bool:
"""Verifies a download location - Creates one if needed."""
for folder in self.file_system.folders:
if self.file_system.folders[folder].name == "downloads":
self.file_system.num_file_creations += 1
return True
self.file_system.create_folder("downloads")
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
return True
return False

View File

@@ -21,6 +21,8 @@ class PacketCapture:
The PCAPs are logged to: <simulation output directory>/<hostname>/<hostname>_<ip address>_pcap.log
"""
_logger_instances: List[logging.Logger] = []
def __init__(
self,
hostname: str,
@@ -65,10 +67,12 @@ class PacketCapture:
if outbound:
self.outbound_logger = logging.getLogger(self._get_logger_name(outbound))
PacketCapture._logger_instances.append(self.outbound_logger)
logger = self.outbound_logger
else:
self.inbound_logger = logging.getLogger(self._get_logger_name(outbound))
logger = self.inbound_logger
PacketCapture._logger_instances.append(self.inbound_logger)
logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs
logger.addHandler(file_handler)
@@ -122,3 +126,13 @@ class PacketCapture:
if SIM_OUTPUT.save_pcap_logs:
msg = frame.model_dump_json()
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
@staticmethod
def clear():
"""Close all open PCAP file handlers."""
for logger in PacketCapture._logger_instances:
handlers = logger.handlers[:]
for handler in handlers:
logger.removeHandler(handler)
handler.close()
PacketCapture._logger_instances = []

View File

@@ -2,7 +2,7 @@
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union
@@ -11,16 +11,16 @@ from typing import Any, Dict, Tuple, Union
import polars as pl
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
def total_rewards_dict(total_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
"""
Read an average rewards per episode csv file and return as a dict.
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:param total_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode csv as a dict.
"""
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
df_dict = pl.read_csv(total_rewards_csv_file).to_dict()
return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}

View File

@@ -2,7 +2,7 @@
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union
@@ -26,9 +26,9 @@ class SessionOutputWriter:
Is used to write session outputs to csv file.
"""
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
_TOTAL_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
"Episode",
"Average Reward",
"Total Reward",
]
def __init__(
@@ -43,7 +43,7 @@ class SessionOutputWriter:
:param env: PrimAITE gym environment.
:type env: Primaite
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
If `false` it will output the average reward per episode, defaults to False
If `false` it will output the total reward per episode, defaults to False
:type transaction_writer: bool, optional
:param learning_session: Set to `true` to indicate that the current session is a training session. This
determines the name of the folder which contains the final output csv. Defaults to True
@@ -56,7 +56,7 @@ class SessionOutputWriter:
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
fn = f"total_reward_per_episode_{self._env.timestamp_str}.csv"
self._csv_file_path: "Path"
if self.learning_session:
@@ -94,7 +94,7 @@ class SessionOutputWriter:
if isinstance(data, Transaction):
header, data = data.as_csv_data()
else:
header = self._AV_REWARD_PER_EPISODE_HEADER
header = self._TOTAL_REWARD_PER_EPISODE_HEADER
if self._first_write:
self._init_csv_writer()