Merged PR 260: Beta 6 Fixes
Please run some of these changes locally. Read the notebook, check that it makes sense, and run the code cells to see if they produce the result you expect. ## Summary Apologies that all these fixes are part of 1 massive PR instead of individual PRs. I thought it was going to be a quick job and it spiralled out of control. Changes: - Fixed a bug where ACL rules were not resetting on episode reset. - Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses - Fixed a bug where deleted files and folders did not reset correctly on episode reset. - Fixed a bug where service health status was using the actual health state instead of the visible health state - Fixed a bug where the database file health status was using the incorrect value for negative rewards - Fixed a bug preventing file actions from reaching their intended file - Made database patch correctly take 2 timesteps instead of being immediate - Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING - Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon. - Fix a bug where ACLs were not showing up correctly in the observation space. - Added a recap of agent actions to the `info` output of `step()` - Added a notebook which explains UC2, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function. ## Test process New notebook verifies end-to-end UC2 functionality. ## Checklist - [y] PR is linked to a **work item** - [y] **acceptance criteria** of linked ticket are met - [y] performed **self-review** of the code - [~] written **tests** for any new functionality added with this PR - [y] updated the **documentation** if this PR changes or adds functionality - [n] written/updated **design docs** if this PR implements new functionality - [y] updated the **change log** - [y] ran **pre-commit** checks for code style - [y] attended to any **TO-DOs** left in the code Related work items: #2208, #2218, #2219, #2220
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -156,3 +156,4 @@ benchmark/output
|
||||
# src/primaite/notebooks/scratch.ipynb
|
||||
src/primaite/notebooks/scratch.py
|
||||
sandbox.py
|
||||
sandbox/
|
||||
|
||||
11
CHANGELOG.md
11
CHANGELOG.md
@@ -6,6 +6,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Fixed a bug where ACL rules were not resetting on episode reset.
|
||||
- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses
|
||||
- Fixed a bug where deleted files and folders did not reset correctly on episode reset.
|
||||
- Fixed a bug where service health status was using the actual health state instead of the visible health state
|
||||
- Fixed a bug where the database file health status was using the incorrect value for negative rewards
|
||||
- Fixed a bug preventing file actions from reaching their intended file
|
||||
- Made database patch correctly take 2 timesteps instead of being immediate
|
||||
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
|
||||
- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon.
|
||||
- Fix a bug where ACLs were not showing up correctly in the observation space.
|
||||
- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.
|
||||
- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config.
|
||||
- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config.
|
||||
- Fixed an issue where the data manipulation attack was triggered at episode start.
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.0.0b5
|
||||
3.0.0b6dev
|
||||
|
||||
@@ -31,7 +31,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -304,63 +304,63 @@ agents:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 5
|
||||
22:
|
||||
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
24: # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
25: # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
@@ -504,6 +504,24 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
|
||||
@@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb: str = "delete"
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
if node_uuid is None or folder_uuid is None or file_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["do_nothing"]
|
||||
# return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid]
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
"""Action which repairs a file."""
|
||||
@@ -460,13 +470,13 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
dst_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id)
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
return [
|
||||
@@ -914,6 +924,15 @@ class ActionManager:
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
ip_address_order = cfg["options"].pop("ip_address_order", {})
|
||||
ip_address_list = []
|
||||
for entry in ip_address_order:
|
||||
node_ref = entry["node_ref"]
|
||||
nic_num = entry["nic_num"]
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_ref)
|
||||
ip_address = node_obj.ethernet_port[nic_num].ip_address
|
||||
ip_address_list.append(ip_address)
|
||||
|
||||
obj = cls(
|
||||
game=game,
|
||||
actions=cfg["action_list"],
|
||||
@@ -921,7 +940,7 @@ class ActionManager:
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=None,
|
||||
ip_address_list=ip_address_list or None,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
@@ -648,10 +649,13 @@ class AclObservation(AbstractObservation):
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
@@ -660,15 +664,26 @@ class AclObservation(AbstractObservation):
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
"source_port": self.port_to_id[rule_state["src_port"]],
|
||||
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
|
||||
"dest_port": self.port_to_id[rule_state["dst_port"]],
|
||||
"protocol": self.protocol_to_id[rule_state["protocol"]],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -110,6 +110,13 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
:type state: Dict
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.info(
|
||||
f"Could not calculate {self.__class__} reward because "
|
||||
"simulation state did not contain enough information."
|
||||
)
|
||||
return 0.0
|
||||
|
||||
health_status = database_file_state["health_status"]
|
||||
if health_status == 2:
|
||||
return -1
|
||||
|
||||
@@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
@@ -115,7 +113,7 @@ class PrimaiteGame:
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
agent_actions = self.apply_agent_actions() # noqa
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
@@ -133,12 +131,15 @@ class PrimaiteGame:
|
||||
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
agent_actions = {}
|
||||
for agent in self.agents:
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
return agent_actions
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
@@ -227,31 +228,7 @@ class PrimaiteGame:
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
elif n_type == "router":
|
||||
new_node = Router(
|
||||
hostname=node_cfg["hostname"],
|
||||
num_ports=node_cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in node_cfg:
|
||||
for port_num, port_cfg in node_cfg["ports"].items():
|
||||
new_node.configure_port(
|
||||
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
|
||||
)
|
||||
# new_node.enable_port(port_num)
|
||||
if "acl" in node_cfg:
|
||||
for r_num, r_cfg in node_cfg["acl"].items():
|
||||
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
|
||||
# this: 'r_cfg.get('src_port')'
|
||||
# Port/IPProtocol. TODO Refactor
|
||||
new_node.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("ip_address"),
|
||||
dst_ip_address=r_cfg.get("ip_address"),
|
||||
position=r_num,
|
||||
)
|
||||
new_node = Router.from_config(node_cfg)
|
||||
else:
|
||||
_LOGGER.warning(f"invalid node type {n_type} in config")
|
||||
if "services" in node_cfg:
|
||||
|
||||
BIN
src/primaite/notebooks/_package_data/uc2_attack.png
Normal file
BIN
src/primaite/notebooks/_package_data/uc2_attack.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 110 KiB |
BIN
src/primaite/notebooks/_package_data/uc2_network.png
Normal file
BIN
src/primaite/notebooks/_package_data/uc2_network.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
@@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
@@ -172,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -186,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {}
|
||||
infos = {"agent_actions": agent_actions}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
|
||||
@@ -23,7 +23,6 @@ class FileSystem(SimComponent):
|
||||
"List containing all the folders in the file system."
|
||||
deleted_folders: Dict[str, Folder] = {}
|
||||
"List containing all the folders that have been deleted."
|
||||
_folders_by_name: Dict[str, Folder] = {}
|
||||
sys_log: SysLog
|
||||
"Instance of SysLog used to create system logs."
|
||||
sim_root: Path
|
||||
@@ -56,7 +55,6 @@ class FileSystem(SimComponent):
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
@@ -67,7 +65,6 @@ class FileSystem(SimComponent):
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
@@ -173,7 +170,6 @@ class FileSystem(SimComponent):
|
||||
folder = Folder(name=folder_name, sys_log=self.sys_log)
|
||||
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
self._folder_request_manager.add_request(
|
||||
name=folder.uuid, request_type=RequestType(func=folder._request_manager)
|
||||
)
|
||||
@@ -188,14 +184,13 @@ class FileSystem(SimComponent):
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
@@ -221,7 +216,10 @@ class FileSystem(SimComponent):
|
||||
:param folder_name: The folder name.
|
||||
:return: The matching Folder.
|
||||
"""
|
||||
return self._folders_by_name.get(folder_name)
|
||||
for folder in self.folders.values():
|
||||
if folder.name == folder_name:
|
||||
return folder
|
||||
return None
|
||||
|
||||
def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]:
|
||||
"""
|
||||
@@ -261,13 +259,13 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name:
|
||||
# check if file with name already exists
|
||||
folder = self._folders_by_name.get(folder_name)
|
||||
folder = self.get_folder(folder_name)
|
||||
# If not then create it
|
||||
if not folder:
|
||||
folder = self.create_folder(folder_name)
|
||||
else:
|
||||
# Use root folder if folder_name not supplied
|
||||
folder = self._folders_by_name["root"]
|
||||
folder = self.get_folder("root")
|
||||
|
||||
# Create the file and add it to the folder
|
||||
file = File(
|
||||
@@ -474,7 +472,6 @@ class FileSystem(SimComponent):
|
||||
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
if folder.deleted:
|
||||
self.deleted_folders.pop(folder.uuid)
|
||||
|
||||
@@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
|
||||
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -17,8 +17,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
files: Dict[str, File] = {}
|
||||
"Files stored in the folder."
|
||||
_files_by_name: Dict[str, File] = {}
|
||||
"Files by their name as <file name>.<file type>."
|
||||
deleted_files: Dict[str, File] = {}
|
||||
"Files that have been deleted."
|
||||
|
||||
@@ -78,7 +76,6 @@ class Folder(FileSystemItemABC):
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
@@ -89,7 +86,6 @@ class Folder(FileSystemItemABC):
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
@@ -219,7 +215,10 @@ class Folder(FileSystemItemABC):
|
||||
:return: The matching File.
|
||||
"""
|
||||
# TODO: Increment read count?
|
||||
return self._files_by_name.get(file_name)
|
||||
for file in self.files.values():
|
||||
if file.name == file_name:
|
||||
return file
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File:
|
||||
"""
|
||||
@@ -250,15 +249,14 @@ class Folder(FileSystemItemABC):
|
||||
raise Exception(f"Invalid file: {file}")
|
||||
|
||||
# check if file with id or name already exists in folder
|
||||
if (force is not True) and file.name in self._files_by_name:
|
||||
if self.get_file(file.name) is not None and not force:
|
||||
raise Exception(f"File with name {file.name} already exists in folder")
|
||||
|
||||
if (force is not True) and file.uuid in self.files:
|
||||
if (file.uuid in self.files) and not force:
|
||||
raise Exception(f"File with uuid {file.uuid} already exists in folder")
|
||||
|
||||
# add to list
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
|
||||
file.folder = self
|
||||
|
||||
@@ -275,11 +273,10 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
if self.files.get(file.uuid):
|
||||
self.files.pop(file.uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
self.deleted_files[file.uuid] = file
|
||||
file.delete()
|
||||
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
|
||||
self._file_request_manager.remove_request(file.uuid)
|
||||
# self._file_request_manager.remove_request(file.uuid)
|
||||
else:
|
||||
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
|
||||
|
||||
@@ -300,7 +297,6 @@ class Folder(FileSystemItemABC):
|
||||
self.deleted_files[file_id] = file
|
||||
|
||||
self.files = {}
|
||||
self._files_by_name = {}
|
||||
|
||||
def restore_file(self, file_uuid: str):
|
||||
"""
|
||||
@@ -316,7 +312,6 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file_uuid)
|
||||
|
||||
@@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
@@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
class ACLAction(Enum):
|
||||
"""Enum for defining the ACL action types."""
|
||||
|
||||
DENY = 0
|
||||
PERMIT = 1
|
||||
DENY = 2
|
||||
|
||||
|
||||
class ACLRule(SimComponent):
|
||||
@@ -65,11 +66,11 @@ class ACLRule(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -89,6 +90,8 @@ class AccessControlList(SimComponent):
|
||||
implicit_rule: ACLRule
|
||||
max_acl_rules: int = 25
|
||||
_acl: List[Optional[ACLRule]] = [None] * 24
|
||||
_default_config: Dict[int, dict] = {}
|
||||
"""Config dict describing how the ACL list should look at episode start"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
if not kwargs.get("implicit_action"):
|
||||
@@ -106,10 +109,40 @@ class AccessControlList(SimComponent):
|
||||
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
|
||||
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
|
||||
|
||||
for i, rule in enumerate(self._acl):
|
||||
if not rule:
|
||||
continue
|
||||
self._default_config[i] = {"action": rule.action.name}
|
||||
if rule.src_ip_address:
|
||||
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
|
||||
if rule.dst_ip_address:
|
||||
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
|
||||
if rule.src_port:
|
||||
self._default_config[i]["src_port"] = rule.src_port.name
|
||||
if rule.dst_port:
|
||||
self._default_config[i]["dst_port"] = rule.dst_port.name
|
||||
if rule.protocol:
|
||||
self._default_config[i]["protocol"] = rule.protocol.name
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
self.implicit_rule.reset_component_for_episode(episode)
|
||||
super().reset_component_for_episode(episode)
|
||||
self._reset_rules_to_default()
|
||||
|
||||
def _reset_rules_to_default(self) -> None:
|
||||
"""Clear all ACL rules and set them to the default rules config."""
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
for r_num, r_cfg in self._default_config.items():
|
||||
self.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -391,7 +424,6 @@ class RouteTable(SimComponent):
|
||||
sys_log: SysLog
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
self._original_state["routes_orig"] = self.routes
|
||||
@@ -716,8 +748,8 @@ class Router(Node):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
state["num_ports"] = self.num_ports
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
@@ -864,3 +896,63 @@ class Router(Node):
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "Router":
|
||||
"""Create a router based on a config dict.
|
||||
|
||||
Schema:
|
||||
- hostname (str): unique name for this router.
|
||||
- num_ports (int, optional): Number of network ports on the router. 8 by default
|
||||
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
|
||||
ip_address and subnet_mask assigned to that ports (as strings)
|
||||
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
|
||||
where the rule will be added (lower number is resolved first). The values should describe valid ACL
|
||||
Rules as:
|
||||
- action (str): either PERMIT or DENY
|
||||
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
|
||||
Example config:
|
||||
```
|
||||
{
|
||||
'hostname': 'router_1',
|
||||
'num_ports': 5,
|
||||
'ports': {
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
:param cfg: Router config adhering to schema described in main docstring body
|
||||
:type cfg: dict
|
||||
:return: Configured router.
|
||||
:rtype: Router
|
||||
"""
|
||||
new = Router(
|
||||
hostname=cfg["hostname"],
|
||||
num_ports=cfg.get("num_ports"),
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
if "ports" in cfg:
|
||||
for port_num, port_cfg in cfg["ports"].items():
|
||||
new.configure_port(
|
||||
port=port_num,
|
||||
ip_address=port_cfg["ip_address"],
|
||||
subnet_mask=port_cfg["subnet_mask"],
|
||||
)
|
||||
if "acl" in cfg:
|
||||
new.acl._default_config = cfg["acl"] # save the config to allow resetting
|
||||
new.acl._reset_rules_to_default() # read the config and apply rules
|
||||
return new
|
||||
|
||||
@@ -84,6 +84,10 @@ class DatabaseService(Service):
|
||||
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
|
||||
|
||||
# send backup copy of database file to FTP server
|
||||
if not self.db_file:
|
||||
self.sys_log.error("Attempted to backup database file but it doesn't exist.")
|
||||
return False
|
||||
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server_ip,
|
||||
src_file_name=self.db_file.name,
|
||||
@@ -121,7 +125,7 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name="database", file_name="downloads.db")
|
||||
self.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
|
||||
|
||||
if self.db_file is None:
|
||||
|
||||
@@ -19,7 +19,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -491,6 +491,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -23,7 +23,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -502,6 +502,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -29,7 +29,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -509,6 +509,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -940,6 +957,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -27,7 +27,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -507,6 +507,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -23,7 +23,7 @@ game:
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
@@ -503,6 +503,23 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
|
||||
|
||||
def filter_keys_nested_item(data, keys):
|
||||
stack = [(data, {})]
|
||||
while stack:
|
||||
current, filtered = stack.pop()
|
||||
if isinstance(current, dict):
|
||||
for k, v in current.items():
|
||||
if k in keys:
|
||||
filtered[k] = filter_keys_nested_item(v, keys)
|
||||
elif isinstance(v, (dict, list)):
|
||||
stack.append((v, {}))
|
||||
elif isinstance(current, list):
|
||||
for item in current:
|
||||
stack.append((item, {}))
|
||||
return filtered
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def network(example_network) -> Network:
|
||||
assert len(example_network.routers) is 1
|
||||
@@ -59,10 +75,10 @@ def test_reset_network(network):
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
|
||||
state_before, sort_keys=True, indent=2
|
||||
)
|
||||
# don't worry if UUIDs change
|
||||
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
|
||||
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_creating_container():
|
||||
|
||||
Reference in New Issue
Block a user