Merged PR 260: Beta 6 Fixes

Please run some of these changes locally. Read the notebook, check that it makes sense, and run the code cells to see if they produce the result you expect.

## Summary
Apologies that all these fixes are part of 1 massive PR instead of individual PRs. I thought it was going to be a quick job and it spiralled out of control.

Changes:
- Fixed a bug where ACL rules were not resetting on episode reset.
- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses
- Fixed a bug where deleted files and folders did not reset correctly on episode reset.
- Fixed a bug where service health status was using the actual health state instead of the visible health state
- Fixed a bug where the database file health status was using the incorrect value for negative rewards
- Fixed a bug preventing file actions from reaching their intended file
- Made database patch correctly take 2 timesteps instead of being immediate
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon.
- Fix a bug where ACLs were not showing up correctly in the observation space.
- Added a recap of agent actions to the `info` output of `step()`
- Added a notebook which explains UC2, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.

## Test process
New notebook verifies end-to-end UC2 functionality.

## Checklist
- [y] PR is linked to a **work item**
- [y] **acceptance criteria** of linked ticket are met
- [y] performed **self-review** of the code
- [~] written **tests** for any new functionality added with this PR
- [y] updated the **documentation** if this PR changes or adds functionality
- [n] written/updated **design docs** if this PR implements new functionality
- [y] updated the **change log**
- [y] ran **pre-commit** checks for code style
- [y] attended to any **TO-DOs** left in the code

Related work items: #2208, #2218, #2219, #2220
This commit is contained in:
Marek Wolan
2024-01-25 15:15:48 +00:00
24 changed files with 1166 additions and 271 deletions

1
.gitignore vendored
View File

@@ -156,3 +156,4 @@ benchmark/output
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py
sandbox.py
sandbox/

View File

@@ -6,6 +6,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Fixed a bug where ACL rules were not resetting on episode reset.
- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses
- Fixed a bug where deleted files and folders did not reset correctly on episode reset.
- Fixed a bug where service health status was using the actual health state instead of the visible health state
- Fixed a bug where the database file health status was using the incorrect value for negative rewards
- Fixed a bug preventing file actions from reaching their intended file
- Made database patch correctly take 2 timesteps instead of being immediate
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon.
- Fix a bug where ACLs were not showing up correctly in the observation space.
- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.
- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config.
- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config.
- Fixed an issue where the data manipulation attack was triggered at episode start.

View File

@@ -1 +1 @@
3.0.0b5
3.0.0b6dev

View File

@@ -31,7 +31,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -304,63 +304,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -504,6 +504,24 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -25,7 +25,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:

View File

@@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb: str = "delete"
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
if node_uuid is None or folder_uuid is None or file_uuid is None:
return ["do_nothing"]
return ["do_nothing"]
# return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid]
class NodeFileRepairAction(NodeFileAbstractAction):
"""Action which repairs a file."""
@@ -460,13 +470,13 @@ class NetworkACLAddRuleAction(AbstractAction):
dst_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id)
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
@@ -914,6 +924,15 @@ class ActionManager:
:return: The constructed ActionManager.
:rtype: ActionManager
"""
ip_address_order = cfg["options"].pop("ip_address_order", {})
ip_address_list = []
for entry in ip_address_order:
node_ref = entry["node_ref"]
nic_num = entry["nic_num"]
node_obj = game.simulation.network.get_node_by_hostname(node_ref)
ip_address = node_obj.ethernet_port[nic_num].ip_address
ip_address_list.append(ip_address)
obj = cls(
game=game,
actions=cfg["action_list"],
@@ -921,7 +940,7 @@ class ActionManager:
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=None,
ip_address_list=ip_address_list or None,
act_map=cfg.get("action_map"),
)

View File

@@ -1,5 +1,6 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
@@ -648,10 +649,13 @@ class AclObservation(AbstractObservation):
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
for i, rule_state in acl_state.items():
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i + 1] = {
"position": i,
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
@@ -660,15 +664,26 @@ class AclObservation(AbstractObservation):
"protocol": 0,
}
else:
obs[i + 1] = {
"position": i,
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
"source_port": self.port_to_id[rule_state["src_port"]],
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
"dest_port": self.port_to_id[rule_state["dst_port"]],
"protocol": self.protocol_to_id[rule_state["protocol"]],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property

View File

@@ -110,6 +110,13 @@ class DatabaseFileIntegrity(AbstractReward):
:type state: Dict
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
_LOGGER.info(
f"Could not calculate {self.__class__} reward because "
"simulation state did not contain enough information."
)
return 0.0
health_status = database_file_state["health_status"]
if health_status == 2:
return -1

View File

@@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.router import Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
@@ -115,7 +113,7 @@ class PrimaiteGame:
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
agent_actions = self.apply_agent_actions() # noqa
# Advance timestep
self.advance_timestep()
@@ -133,12 +131,15 @@ class PrimaiteGame:
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
agent_actions = {}
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
return agent_actions
def advance_timestep(self) -> None:
"""Advance timestep."""
@@ -227,31 +228,7 @@ class PrimaiteGame:
operating_state=NodeOperatingState.ON,
)
elif n_type == "router":
new_node = Router(
hostname=node_cfg["hostname"],
num_ports=node_cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
# new_node.enable_port(port_num)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
# this: 'r_cfg.get('src_port')'
# Port/IPProtocol. TODO Refactor
new_node.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("ip_address"),
dst_ip_address=r_cfg.get("ip_address"),
position=r_num,
)
new_node = Router.from_config(node_cfg)
else:
_LOGGER.warning(f"invalid node type {n_type} in config")
if "services" in node_cfg:

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
agent_actions = self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
@@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env):
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {}
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
return next_obs, reward, terminated, truncated, info
@@ -172,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.apply_agent_actions()
agent_actions = self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -186,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {}
infos = {"agent_actions": agent_actions}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:

View File

@@ -23,7 +23,6 @@ class FileSystem(SimComponent):
"List containing all the folders in the file system."
deleted_folders: Dict[str, Folder] = {}
"List containing all the folders that have been deleted."
_folders_by_name: Dict[str, Folder] = {}
sys_log: SysLog
"Instance of SysLog used to create system logs."
sim_root: Path
@@ -56,7 +55,6 @@ class FileSystem(SimComponent):
folder = self.deleted_folders[uuid]
self.deleted_folders.pop(uuid)
self.folders[uuid] = folder
self._folders_by_name[folder.name] = folder
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
@@ -67,7 +65,6 @@ class FileSystem(SimComponent):
if uuid not in original_folder_uuids:
folder = self.folders[uuid]
self.folders.pop(uuid)
self._folders_by_name.pop(folder.name)
# Now reset all remaining folders
for folder in self.folders.values():
@@ -173,7 +170,6 @@ class FileSystem(SimComponent):
folder = Folder(name=folder_name, sys_log=self.sys_log)
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self._folder_request_manager.add_request(
name=folder.uuid, request_type=RequestType(func=folder._request_manager)
)
@@ -188,14 +184,13 @@ class FileSystem(SimComponent):
if folder_name == "root":
self.sys_log.warning("Cannot delete the root folder.")
return
folder = self._folders_by_name.get(folder_name)
folder = self.get_folder(folder_name)
if folder:
# set folder to deleted state
folder.delete()
# remove from folder list
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
# add to deleted list
folder.remove_all_files()
@@ -221,7 +216,10 @@ class FileSystem(SimComponent):
:param folder_name: The folder name.
:return: The matching Folder.
"""
return self._folders_by_name.get(folder_name)
for folder in self.folders.values():
if folder.name == folder_name:
return folder
return None
def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]:
"""
@@ -261,13 +259,13 @@ class FileSystem(SimComponent):
"""
if folder_name:
# check if file with name already exists
folder = self._folders_by_name.get(folder_name)
folder = self.get_folder(folder_name)
# If not then create it
if not folder:
folder = self.create_folder(folder_name)
else:
# Use root folder if folder_name not supplied
folder = self._folders_by_name["root"]
folder = self.get_folder("root")
# Create the file and add it to the folder
file = File(
@@ -474,7 +472,6 @@ class FileSystem(SimComponent):
folder.restore()
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
if folder.deleted:
self.deleted_folders.pop(folder.uuid)

View File

@@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent):
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:

View File

@@ -17,8 +17,6 @@ class Folder(FileSystemItemABC):
files: Dict[str, File] = {}
"Files stored in the folder."
_files_by_name: Dict[str, File] = {}
"Files by their name as <file name>.<file type>."
deleted_files: Dict[str, File] = {}
"Files that have been deleted."
@@ -78,7 +76,6 @@ class Folder(FileSystemItemABC):
file = self.deleted_files[uuid]
self.deleted_files.pop(uuid)
self.files[uuid] = file
self._files_by_name[file.name] = file
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
@@ -89,7 +86,6 @@ class Folder(FileSystemItemABC):
if uuid not in original_file_uuids:
file = self.files[uuid]
self.files.pop(uuid)
self._files_by_name.pop(file.name)
# Now reset all remaining files
for file in self.files.values():
@@ -219,7 +215,10 @@ class Folder(FileSystemItemABC):
:return: The matching File.
"""
# TODO: Increment read count?
return self._files_by_name.get(file_name)
for file in self.files.values():
if file.name == file_name:
return file
return None
def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File:
"""
@@ -250,15 +249,14 @@ class Folder(FileSystemItemABC):
raise Exception(f"Invalid file: {file}")
# check if file with id or name already exists in folder
if (force is not True) and file.name in self._files_by_name:
if self.get_file(file.name) is not None and not force:
raise Exception(f"File with name {file.name} already exists in folder")
if (force is not True) and file.uuid in self.files:
if (file.uuid in self.files) and not force:
raise Exception(f"File with uuid {file.uuid} already exists in folder")
# add to list
self.files[file.uuid] = file
self._files_by_name[file.name] = file
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
file.folder = self
@@ -275,11 +273,10 @@ class Folder(FileSystemItemABC):
if self.files.get(file.uuid):
self.files.pop(file.uuid)
self._files_by_name.pop(file.name)
self.deleted_files[file.uuid] = file
file.delete()
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
self._file_request_manager.remove_request(file.uuid)
# self._file_request_manager.remove_request(file.uuid)
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
@@ -300,7 +297,6 @@ class Folder(FileSystemItemABC):
self.deleted_files[file_id] = file
self.files = {}
self._files_by_name = {}
def restore_file(self, file_uuid: str):
"""
@@ -316,7 +312,6 @@ class Folder(FileSystemItemABC):
file.restore()
self.files[file.uuid] = file
self._files_by_name[file.name] = file
if file.deleted:
self.deleted_files.pop(file_uuid)

View File

@@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
@@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
class ACLAction(Enum):
"""Enum for defining the ACL action types."""
DENY = 0
PERMIT = 1
DENY = 2
class ACLRule(SimComponent):
@@ -65,11 +66,11 @@ class ACLRule(SimComponent):
"""
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["protocol"] = self.protocol.name if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["src_port"] = self.src_port.name if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
state["dst_port"] = self.dst_port.name if self.dst_port else None
return state
@@ -89,6 +90,8 @@ class AccessControlList(SimComponent):
implicit_rule: ACLRule
max_acl_rules: int = 25
_acl: List[Optional[ACLRule]] = [None] * 24
_default_config: Dict[int, dict] = {}
"""Config dict describing how the ACL list should look at episode start"""
def __init__(self, **kwargs) -> None:
if not kwargs.get("implicit_action"):
@@ -106,10 +109,40 @@ class AccessControlList(SimComponent):
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
for i, rule in enumerate(self._acl):
if not rule:
continue
self._default_config[i] = {"action": rule.action.name}
if rule.src_ip_address:
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
if rule.dst_ip_address:
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
if rule.src_port:
self._default_config[i]["src_port"] = rule.src_port.name
if rule.dst_port:
self._default_config[i]["dst_port"] = rule.dst_port.name
if rule.protocol:
self._default_config[i]["protocol"] = rule.protocol.name
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
self._reset_rules_to_default()
def _reset_rules_to_default(self) -> None:
"""Clear all ACL rules and set them to the default rules config."""
self._acl = [None] * (self.max_acl_rules - 1)
for r_num, r_cfg in self._default_config.items():
self.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -391,7 +424,6 @@ class RouteTable(SimComponent):
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
@@ -716,8 +748,8 @@ class Router(Node):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = (self.num_ports,)
state["acl"] = (self.acl.describe_state(),)
state["num_ports"] = self.num_ports
state["acl"] = self.acl.describe_state()
return state
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
@@ -864,3 +896,63 @@ class Router(Node):
]
)
print(table)
@classmethod
def from_config(cls, cfg: dict) -> "Router":
"""Create a router based on a config dict.
Schema:
- hostname (str): unique name for this router.
- num_ports (int, optional): Number of network ports on the router. 8 by default
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
ip_address and subnet_mask assigned to that ports (as strings)
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
where the rule will be added (lower number is resolved first). The values should describe valid ACL
Rules as:
- action (str): either PERMIT or DENY
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
Example config:
```
{
'hostname': 'router_1',
'num_ports': 5,
'ports': {
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
}
},
'acl' : {
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
}
```
:param cfg: Router config adhering to schema described in main docstring body
:type cfg: dict
:return: Configured router.
:rtype: Router
"""
new = Router(
hostname=cfg["hostname"],
num_ports=cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
new.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=port_cfg["subnet_mask"],
)
if "acl" in cfg:
new.acl._default_config = cfg["acl"] # save the config to allow resetting
new.acl._reset_rules_to_default() # read the config and apply rules
return new

View File

@@ -84,6 +84,10 @@ class DatabaseService(Service):
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
# send backup copy of database file to FTP server
if not self.db_file:
self.sys_log.error("Attempted to backup database file but it doesn't exist.")
return False
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server_ip,
src_file_name=self.db_file.name,
@@ -121,7 +125,7 @@ class DatabaseService(Service):
return False
# replace db file
self.file_system.delete_file(folder_name="database", file_name="downloads.db")
self.file_system.delete_file(folder_name="database", file_name="database.db")
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
if self.db_file is None:

View File

@@ -19,7 +19,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -491,6 +491,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -23,7 +23,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -502,6 +502,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -29,7 +29,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -509,6 +509,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:
@@ -940,6 +957,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -27,7 +27,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -507,6 +507,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -23,7 +23,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -503,6 +503,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
from primaite.simulator.system.services.database.database_service import DatabaseService
def filter_keys_nested_item(data, keys):
stack = [(data, {})]
while stack:
current, filtered = stack.pop()
if isinstance(current, dict):
for k, v in current.items():
if k in keys:
filtered[k] = filter_keys_nested_item(v, keys)
elif isinstance(v, (dict, list)):
stack.append((v, {}))
elif isinstance(current, list):
for item in current:
stack.append((item, {}))
return filtered
@pytest.fixture(scope="function")
def network(example_network) -> Network:
assert len(example_network.routers) is 1
@@ -59,10 +75,10 @@ def test_reset_network(network):
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
state_before, sort_keys=True, indent=2
)
# don't worry if UUIDs change
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
assert a == b
def test_creating_container():