Merged PR 263: Several hotfixes

## Summary
Hotfixes from 3.0.0b4, b5, b6. These have all gone thru the PR process already we just need to sync dev back.

## Test process
Merge conflicts resolved and all automated tests pass.

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [x] written **tests** for any new functionality added with this PR
- [~] updated the **documentation** if this PR changes or adds functionality
- [~] written/updated **design docs** if this PR implements new functionality
- [~] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Related work items: #2161, #2173, #2174, #2175, #2176, #2179, #2208, #2218, #2219, #2220
This commit is contained in:
Marek Wolan
2024-01-30 15:18:21 +00:00
38 changed files with 1308 additions and 337 deletions

1
.gitignore vendored
View File

@@ -156,4 +156,5 @@ benchmark/output
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py
sandbox.py
sandbox/
sandbox.ipynb

View File

@@ -6,8 +6,25 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Fixed a bug where ACL rules were not resetting on episode reset.
- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses
- Fixed a bug where deleted files and folders did not reset correctly on episode reset.
- Fixed a bug where service health status was using the actual health state instead of the visible health state
- Fixed a bug where the database file health status was using the incorrect value for negative rewards
- Fixed a bug preventing file actions from reaching their intended file
- Made database patch correctly take 2 timesteps instead of being immediate
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon.
- Fix a bug where ACLs were not showing up correctly in the observation space.
- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.
- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config.
- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config.
- Fixed an issue where the data manipulation attack was triggered at episode start.
- Fixed a bug where FTP STOR stored an additional copy on the client machine's filesystem
- Fixed a bug where the red agent acted to early
- Fixed the order of service health state
- Fixed an issue where starting a node didn't start the services on it
### Added

View File

@@ -33,7 +33,7 @@ Currently, the PrimAITE wheel can only be installed from GitHub. This may change
#### Windows (PowerShell)
**Prerequisites:**
* Manual install of Python >= 3.8 < 3.11
* Manual install of Python >= 3.8 < 3.12
**Install:**
@@ -56,7 +56,7 @@ primaite session
#### Unix
**Prerequisites:**
* Manual install of Python >= 3.8 < 3.11
* Manual install of Python >= 3.8 < 3.12
``` bash
sudo add-apt-repository ppa:deadsnakes/ppa
@@ -82,6 +82,7 @@ primaite session
```
### Developer Install from Source
To make your own changes to PrimAITE, perform the install from source (developer install)
@@ -138,3 +139,7 @@ make html
cd docs
.\make.bat html
```
## Example notebooks
Check out the example notebooks to learn more about how PrimAITE works and how you can use it to train agents. They are automatically copied to your primaite installation directory when you run `primaite setup`.

View File

@@ -1 +1 @@
3.0.0b4dev
3.0.0b6

View File

@@ -31,7 +31,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -112,10 +112,8 @@ agents:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: web_server_database_client
- service_name: web_server_web_service
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
@@ -306,63 +304,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -506,6 +504,24 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -25,7 +25,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:

View File

@@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb: str = "delete"
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
if node_uuid is None or folder_uuid is None or file_uuid is None:
return ["do_nothing"]
return ["do_nothing"]
# return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid]
class NodeFileRepairAction(NodeFileAbstractAction):
"""Action which repairs a file."""
@@ -443,30 +453,36 @@ class NetworkACLAddRuleAction(AbstractAction):
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if source_ip_id in [0, 1]:
if source_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif source_ip_id == 1:
src_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_port_id == 1:
if source_port_id == 0:
return ["do_nothing"] # invalid formulation
elif source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_ip_id in (0, 1):
if source_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
dst_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_port_id == 1:
if dest_port_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id)
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
@@ -914,6 +930,15 @@ class ActionManager:
:return: The constructed ActionManager.
:rtype: ActionManager
"""
ip_address_order = cfg["options"].pop("ip_address_order", {})
ip_address_list = []
for entry in ip_address_order:
node_ref = entry["node_ref"]
nic_num = entry["nic_num"]
node_obj = game.simulation.network.get_node_by_hostname(node_ref)
ip_address = node_obj.ethernet_port[nic_num].ip_address
ip_address_list.append(ip_address)
obj = cls(
game=game,
actions=cfg["action_list"],
@@ -921,7 +946,7 @@ class ActionManager:
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=None,
ip_address_list=ip_address_list or None,
act_map=cfg.get("action_map"),
)

View File

@@ -15,7 +15,6 @@ class DataManipulationAgent(AbstractScriptedAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _set_next_execution_timestep(self, timestep: int) -> None:
@@ -46,3 +45,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
def reset_agent_for_episode(self) -> None:
"""Set the next execution timestep when the episode resets."""
super().reset_agent_for_episode()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)

View File

@@ -136,6 +136,10 @@ class AbstractAgent(ABC):
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
def reset_agent_for_episode(self) -> None:
"""Agent reset logic should go here."""
pass
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""

View File

@@ -1,5 +1,6 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
@@ -78,7 +79,7 @@ class FileObservation(AbstractObservation):
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["health_status"]}
return {"health_status": file_state["visible_status"]}
@property
def space(self) -> spaces.Space:
@@ -204,12 +205,15 @@ class LinkObservation(AbstractObservation):
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 10) + 1
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": utilisation_category}}
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
@@ -554,7 +558,7 @@ class NodeObservation(AbstractObservation):
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
@@ -644,10 +648,13 @@ class AclObservation(AbstractObservation):
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
for i, rule_state in acl_state.items():
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i + 1] = {
"position": i,
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
@@ -656,15 +663,26 @@ class AclObservation(AbstractObservation):
"protocol": 0,
}
else:
obs[i + 1] = {
"position": i,
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
"source_port": self.port_to_id[rule_state["src_port"]],
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
"dest_port": self.port_to_id[rule_state["dst_port"]],
"protocol": self.protocol_to_id[rule_state["protocol"]],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property

View File

@@ -110,10 +110,17 @@ class DatabaseFileIntegrity(AbstractReward):
:type state: Dict
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
_LOGGER.info(
f"Could not calculate {self.__class__} reward because "
"simulation state did not contain enough information."
)
return 0.0
health_status = database_file_state["health_status"]
if health_status == "corrupted":
if health_status == 2:
return -1
elif health_status == "good":
elif health_status == 1:
return 1
else:
return 0
@@ -161,7 +168,6 @@ class WebServer404Penalty(AbstractReward):
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
if web_service_state is NOT_PRESENT_IN_STATE:
print("error getting web service state")
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.

View File

@@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.router import Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
@@ -117,7 +115,7 @@ class PrimaiteGame:
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
agent_actions = self.apply_agent_actions() # noqa
# Advance timestep
self.advance_timestep()
@@ -135,12 +133,15 @@ class PrimaiteGame:
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
agent_actions = {}
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
return agent_actions
def advance_timestep(self) -> None:
"""Advance timestep."""
@@ -164,6 +165,7 @@ class PrimaiteGame:
self.simulation.reset_component_for_episode(episode=self.episode_counter)
for agent in self.agents:
agent.reward_function.total_reward = 0.0
agent.reset_agent_for_episode()
def close(self) -> None:
"""Close the game, this will close the simulation."""
@@ -228,31 +230,7 @@ class PrimaiteGame:
operating_state=NodeOperatingState.ON,
)
elif n_type == "router":
new_node = Router(
hostname=node_cfg["hostname"],
num_ports=node_cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
# new_node.enable_port(port_num)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
# this: 'r_cfg.get('src_port')'
# Port/IPProtocol. TODO Refactor
new_node.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("ip_address"),
dst_ip_address=r_cfg.get("ip_address"),
position=r_num,
)
new_node = Router.from_config(node_cfg)
else:
_LOGGER.warning(f"invalid node type {n_type} in config")
if "services" in node_cfg:

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
agent_actions = self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
@@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env):
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {}
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
return next_obs, reward, terminated, truncated, info
@@ -172,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.apply_agent_actions()
agent_actions = self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -186,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {}
infos = {"agent_actions": agent_actions}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:

View File

@@ -44,3 +44,12 @@ def run(overwrite_existing: bool = True) -> None:
print(dst_fp)
shutil.copy2(src_fp, dst_fp)
_LOGGER.info(f"Reset example notebook: {dst_fp}")
for src_fp in primaite_root.glob("notebooks/_package_data/*"):
dst_fp = example_notebooks_user_dir / "_package_data" / src_fp.name
if should_copy_file(src_fp, dst_fp, overwrite_existing):
if not Path.exists(example_notebooks_user_dir / "_package_data/"):
Path.mkdir(example_notebooks_user_dir / "_package_data/")
print(dst_fp)
shutil.copy2(src_fp, dst_fp)
_LOGGER.info(f"Copied notebook resource to: {dst_fp}")

View File

@@ -23,7 +23,6 @@ class FileSystem(SimComponent):
"List containing all the folders in the file system."
deleted_folders: Dict[str, Folder] = {}
"List containing all the folders that have been deleted."
_folders_by_name: Dict[str, Folder] = {}
sys_log: SysLog
"Instance of SysLog used to create system logs."
sim_root: Path
@@ -56,7 +55,6 @@ class FileSystem(SimComponent):
folder = self.deleted_folders[uuid]
self.deleted_folders.pop(uuid)
self.folders[uuid] = folder
self._folders_by_name[folder.name] = folder
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
@@ -67,7 +65,6 @@ class FileSystem(SimComponent):
if uuid not in original_folder_uuids:
folder = self.folders[uuid]
self.folders.pop(uuid)
self._folders_by_name.pop(folder.name)
# Now reset all remaining folders
for folder in self.folders.values():
@@ -173,7 +170,6 @@ class FileSystem(SimComponent):
folder = Folder(name=folder_name, sys_log=self.sys_log)
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self._folder_request_manager.add_request(
name=folder.uuid, request_type=RequestType(func=folder._request_manager)
)
@@ -188,14 +184,13 @@ class FileSystem(SimComponent):
if folder_name == "root":
self.sys_log.warning("Cannot delete the root folder.")
return
folder = self._folders_by_name.get(folder_name)
folder = self.get_folder(folder_name)
if folder:
# set folder to deleted state
folder.delete()
# remove from folder list
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
# add to deleted list
folder.remove_all_files()
@@ -221,7 +216,10 @@ class FileSystem(SimComponent):
:param folder_name: The folder name.
:return: The matching Folder.
"""
return self._folders_by_name.get(folder_name)
for folder in self.folders.values():
if folder.name == folder_name:
return folder
return None
def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]:
"""
@@ -261,13 +259,13 @@ class FileSystem(SimComponent):
"""
if folder_name:
# check if file with name already exists
folder = self._folders_by_name.get(folder_name)
folder = self.get_folder(folder_name)
# If not then create it
if not folder:
folder = self.create_folder(folder_name)
else:
# Use root folder if folder_name not supplied
folder = self._folders_by_name["root"]
folder = self.get_folder("root")
# Create the file and add it to the folder
file = File(
@@ -474,7 +472,6 @@ class FileSystem(SimComponent):
folder.restore()
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
if folder.deleted:
self.deleted_folders.pop(folder.uuid)

View File

@@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent):
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:

View File

@@ -17,8 +17,6 @@ class Folder(FileSystemItemABC):
files: Dict[str, File] = {}
"Files stored in the folder."
_files_by_name: Dict[str, File] = {}
"Files by their name as <file name>.<file type>."
deleted_files: Dict[str, File] = {}
"Files that have been deleted."
@@ -78,7 +76,6 @@ class Folder(FileSystemItemABC):
file = self.deleted_files[uuid]
self.deleted_files.pop(uuid)
self.files[uuid] = file
self._files_by_name[file.name] = file
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
@@ -89,7 +86,6 @@ class Folder(FileSystemItemABC):
if uuid not in original_file_uuids:
file = self.files[uuid]
self.files.pop(uuid)
self._files_by_name.pop(file.name)
# Now reset all remaining files
for file in self.files.values():
@@ -105,7 +101,7 @@ class Folder(FileSystemItemABC):
self._file_request_manager = RequestManager()
rm.add_request(
name="file",
request_type=RequestType(func=lambda request, context: self._file_request_manager),
request_type=RequestType(func=self._file_request_manager),
)
return rm
@@ -219,7 +215,10 @@ class Folder(FileSystemItemABC):
:return: The matching File.
"""
# TODO: Increment read count?
return self._files_by_name.get(file_name)
for file in self.files.values():
if file.name == file_name:
return file
return None
def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File:
"""
@@ -250,15 +249,14 @@ class Folder(FileSystemItemABC):
raise Exception(f"Invalid file: {file}")
# check if file with id or name already exists in folder
if (force is not True) and file.name in self._files_by_name:
if self.get_file(file.name) is not None and not force:
raise Exception(f"File with name {file.name} already exists in folder")
if (force is not True) and file.uuid in self.files:
if (file.uuid in self.files) and not force:
raise Exception(f"File with uuid {file.uuid} already exists in folder")
# add to list
self.files[file.uuid] = file
self._files_by_name[file.name] = file
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
file.folder = self
@@ -275,11 +273,9 @@ class Folder(FileSystemItemABC):
if self.files.get(file.uuid):
self.files.pop(file.uuid)
self._files_by_name.pop(file.name)
self.deleted_files[file.uuid] = file
file.delete()
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
self._file_request_manager.remove_request(file.uuid)
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
@@ -300,7 +296,6 @@ class Folder(FileSystemItemABC):
self.deleted_files[file_id] = file
self.files = {}
self._files_by_name = {}
def restore_file(self, file_uuid: str):
"""
@@ -316,7 +311,6 @@ class Folder(FileSystemItemABC):
file.restore()
self.files[file.uuid] = file
self._files_by_name[file.name] = file
if file.deleted:
self.deleted_files.pop(file_uuid)

View File

@@ -1310,8 +1310,8 @@ class Node(SimComponent):
self.start_up_countdown = self.start_up_duration
if self.start_up_duration <= 0:
self._start_up_actions()
self.operating_state = NodeOperatingState.ON
self._start_up_actions()
self.sys_log.info("Turned on")
for nic in self.nics.values():
if nic._connected_link:

View File

@@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
@@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
class ACLAction(Enum):
"""Enum for defining the ACL action types."""
DENY = 0
PERMIT = 1
DENY = 2
class ACLRule(SimComponent):
@@ -65,11 +66,11 @@ class ACLRule(SimComponent):
"""
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["protocol"] = self.protocol.name if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["src_port"] = self.src_port.name if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
state["dst_port"] = self.dst_port.name if self.dst_port else None
return state
@@ -89,6 +90,8 @@ class AccessControlList(SimComponent):
implicit_rule: ACLRule
max_acl_rules: int = 25
_acl: List[Optional[ACLRule]] = [None] * 24
_default_config: Dict[int, dict] = {}
"""Config dict describing how the ACL list should look at episode start"""
def __init__(self, **kwargs) -> None:
if not kwargs.get("implicit_action"):
@@ -106,10 +109,40 @@ class AccessControlList(SimComponent):
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
for i, rule in enumerate(self._acl):
if not rule:
continue
self._default_config[i] = {"action": rule.action.name}
if rule.src_ip_address:
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
if rule.dst_ip_address:
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
if rule.src_port:
self._default_config[i]["src_port"] = rule.src_port.name
if rule.dst_port:
self._default_config[i]["dst_port"] = rule.dst_port.name
if rule.protocol:
self._default_config[i]["protocol"] = rule.protocol.name
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
self._reset_rules_to_default()
def _reset_rules_to_default(self) -> None:
"""Clear all ACL rules and set them to the default rules config."""
self._acl = [None] * (self.max_acl_rules - 1)
for r_num, r_cfg in self._default_config.items():
self.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -129,9 +162,9 @@ class AccessControlList(SimComponent):
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
None if request[1] == "ALL" else IPProtocol[request[1]],
IPv4Address(request[2]),
None if request[2] == "ALL" else IPv4Address(request[2]),
None if request[3] == "ALL" else Port[request[3]],
IPv4Address(request[4]),
None if request[4] == "ALL" else IPv4Address(request[4]),
None if request[5] == "ALL" else Port[request[5]],
int(request[6]),
)
@@ -385,7 +418,6 @@ class RouteTable(SimComponent):
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
@@ -811,8 +843,8 @@ class Router(Node):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = (self.num_ports,)
state["acl"] = (self.acl.describe_state(),)
state["num_ports"] = self.num_ports
state["acl"] = self.acl.describe_state()
return state
def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
@@ -991,3 +1023,63 @@ class Router(Node):
]
)
print(table)
@classmethod
def from_config(cls, cfg: dict) -> "Router":
"""Create a router based on a config dict.
Schema:
- hostname (str): unique name for this router.
- num_ports (int, optional): Number of network ports on the router. 8 by default
- ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying
ip_address and subnet_mask assigned to that ports (as strings)
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
where the rule will be added (lower number is resolved first). The values should describe valid ACL
Rules as:
- action (str): either PERMIT or DENY
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
Example config:
```
{
'hostname': 'router_1',
'num_ports': 5,
'ports': {
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
}
},
'acl' : {
21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'},
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
}
```
:param cfg: Router config adhering to schema described in main docstring body
:type cfg: dict
:return: Configured router.
:rtype: Router
"""
new = Router(
hostname=cfg["hostname"],
num_ports=cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
if "ports" in cfg:
for port_num, port_cfg in cfg["ports"].items():
new.configure_port(
port=port_num,
ip_address=port_cfg["ip_address"],
subnet_mask=port_cfg["subnet_mask"],
)
if "acl" in cfg:
new.acl._default_config = cfg["acl"] # save the config to allow resetting
new.acl._reset_rules_to_default() # read the config and apply rules
return new

View File

@@ -72,7 +72,7 @@ class DataManipulationBot(DatabaseClient):
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
return rm
@@ -83,7 +83,7 @@ class DataManipulationBot(DatabaseClient):
payload: Optional[str] = None,
port_scan_p_of_success: float = 0.1,
data_manipulation_p_of_success: float = 0.1,
repeat: bool = False,
repeat: bool = True,
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
@@ -168,6 +168,12 @@ class DataManipulationBot(DatabaseClient):
Calls the parent classes execute method before starting the application loop.
"""
super().run()
def attack(self):
"""Perform the attack steps after opening the application."""
if not self._can_perform_action():
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
self.run()
self._application_loop()
def _application_loop(self):
@@ -198,4 +204,4 @@ class DataManipulationBot(DatabaseClient):
:param timestep: The timestep value to update the bot's state.
"""
self._application_loop()
pass

View File

@@ -3,6 +3,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.file_system.folder import Folder
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
@@ -22,7 +24,7 @@ class DatabaseService(Service):
password: Optional[str] = None
backup_server: IPv4Address = None
backup_server_ip: IPv4Address = None
"""IP address of the backup server."""
latest_backup_directory: str = None
@@ -36,7 +38,6 @@ class DatabaseService(Service):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
def set_original_state(self):
@@ -45,8 +46,8 @@ class DatabaseService(Service):
super().set_original_state()
vals_to_include = {
"password",
"_connections",
"backup_server",
"connections",
"backup_server_ip",
"latest_backup_directory",
"latest_backup_file_name",
}
@@ -64,7 +65,7 @@ class DatabaseService(Service):
:param: backup_server_ip: The IP address of the backup server
"""
self.backup_server = backup_server
self.backup_server_ip = backup_server
def backup_database(self) -> bool:
"""Create a backup of the database to the configured backup server."""
@@ -73,7 +74,7 @@ class DatabaseService(Service):
return False
# check if the backup server was configured
if self.backup_server is None:
if self.backup_server_ip is None:
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
@@ -81,10 +82,14 @@ class DatabaseService(Service):
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
# send backup copy of database file to FTP server
if not self.db_file:
self.sys_log.error("Attempted to backup database file but it doesn't exist.")
return False
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server,
src_file_name=self._db_file.name,
src_folder_name=self.folder.name,
dest_ip_address=self.backup_server_ip,
src_file_name=self.db_file.name,
src_folder_name="database",
dest_folder_name=str(self.uuid),
dest_file_name="database.db",
)
@@ -110,7 +115,7 @@ class DatabaseService(Service):
src_file_name="database.db",
dest_folder_name="downloads",
dest_file_name="database.db",
dest_ip_address=self.backup_server,
dest_ip_address=self.backup_server_ip,
)
if not response:
@@ -118,13 +123,10 @@ class DatabaseService(Service):
return False
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self.file_system.delete_file(folder_name="database", file_name="database.db")
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
if self._db_file is None:
if self.db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
@@ -134,12 +136,30 @@ class DatabaseService(Service):
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
self.file_system.create_file(folder_name="database", file_name="database.db")
@property
def db_file(self) -> File:
"""Returns the database file."""
return self.file_system.get_file(folder_name="database", file_name="database.db")
@property
def folder(self) -> Folder:
"""Returns the database folder."""
return self.file_system.get_folder_by_id(self.db_file.folder_id)
def _process_connect(
self, connection_id: str, password: Optional[str] = None
) -> Dict[str, Union[int, Dict[str, bool]]]:
"""Process an incoming connection request.
:param connection_id: A unique identifier for the connection
:type connection_id: str
:param password: Supplied password. It must match self.password for connection success, defaults to None
:type password: Optional[str], optional
:return: Response to connection request containing success info.
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
status_code = 500 # Default internal server error
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
@@ -184,7 +204,7 @@ class DatabaseService(Service):
self.sys_log.info(f"{self.name}: Running {query}")
if query == "SELECT":
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
return {
"status_code": 200,
"type": "sql",
@@ -195,17 +215,8 @@ class DatabaseService(Service):
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
if self.health_state_actual == SoftwareHealthState.GOOD:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 404, "data": False}
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id}
else:
# Invalid query
return {"status_code": 500, "data": False}
@@ -265,3 +276,19 @@ class DatabaseService(Service):
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return payload["status_code"] == 200
def apply_timestep(self, timestep: int) -> None:
"""
Apply a single timestep of simulation dynamics to this service.
Here at the first step, the database backup is created, in addition to normal service update logic.
"""
if timestep == 1:
self.backup_database()
return super().apply_timestep(timestep)
def _update_patch_status(self) -> None:
"""Perform a database restore when the patching countdown is finished."""
super()._update_patch_status()
if self._patching_countdown is None:
self.restore_backup()

View File

@@ -89,6 +89,7 @@ class FTPClient(FTPServiceABC):
f"{self.name}: Successfully connected to FTP Server "
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
)
self.add_connection(connection_id="server_connection", session_id=session_id)
return True
else:
if is_reattempt:

View File

@@ -99,5 +99,5 @@ class FTPServer(FTPServiceABC):
if payload.status_code is not None:
return False
self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id)
self._process_ftp_command(payload=payload, session_id=session_id)
return True

View File

@@ -56,10 +56,12 @@ class FTPServiceABC(Service, ABC):
folder_name = payload.ftp_command_args["dest_folder_name"]
file_size = payload.ftp_command_args["file_size"]
real_file_path = payload.ftp_command_args.get("real_file_path")
health_status = payload.ftp_command_args["health_status"]
is_real = real_file_path is not None
file = self.file_system.create_file(
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
)
file.health_status = health_status
self.sys_log.info(
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
f"{payload.ftp_command_args['dest_file_name']}"
@@ -114,6 +116,7 @@ class FTPServiceABC(Service, ABC):
"dest_file_name": dest_file_name,
"file_size": file.sim_size,
"real_file_path": file.sim_path if file.real else None,
"health_status": file.health_status,
},
packet_payload_size=file.sim_size,
status_code=FTPStatusCode.OK if is_response else None,

View File

@@ -13,6 +13,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import SoftwareHealthState
_LOGGER = getLogger(__name__)
@@ -123,7 +124,10 @@ class WebServer(Service):
# get all users
if db_client.query("SELECT"):
# query succeeded
self.set_health_state(SoftwareHealthState.GOOD)
response.status_code = HttpStatusCode.OK
else:
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception:

View File

@@ -38,12 +38,12 @@ class SoftwareHealthState(Enum):
"Unused state."
GOOD = 1
"The software is in a good and healthy condition."
COMPROMISED = 2
"The software's security has been compromised."
OVERWHELMED = 3
"he software is overwhelmed and not functioning properly."
PATCHING = 4
PATCHING = 2
"The software is undergoing patching or updates."
COMPROMISED = 3
"The software's security has been compromised."
OVERWHELMED = 4
"he software is overwhelmed and not functioning properly."
class SoftwareCriticality(Enum):
@@ -195,8 +195,9 @@ class Software(SimComponent):
def patch(self) -> None:
"""Perform a patch on the software."""
self._patching_countdown = self.patching_duration
self.set_health_state(SoftwareHealthState.PATCHING)
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
self._patching_countdown = self.patching_duration
self.set_health_state(SoftwareHealthState.PATCHING)
def _update_patch_status(self) -> None:
"""Update the patch status of the software."""

View File

@@ -19,7 +19,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -491,6 +491,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -23,7 +23,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -495,6 +495,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -29,7 +29,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -502,6 +502,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:
@@ -933,6 +950,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -27,7 +27,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -500,6 +500,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -23,7 +23,7 @@ game:
- UDP
agents:
- ref: client_1_green_user
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
@@ -501,6 +501,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -22,7 +22,7 @@ def test_data_manipulation(uc2_network):
assert db_client.query("SELECT")
# Now we run the DataManipulationBot
db_manipulation_bot.run()
db_manipulation_bot.attack()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT")

View File

@@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
from primaite.simulator.system.services.database.database_service import DatabaseService
def filter_keys_nested_item(data, keys):
stack = [(data, {})]
while stack:
current, filtered = stack.pop()
if isinstance(current, dict):
for k, v in current.items():
if k in keys:
filtered[k] = filter_keys_nested_item(v, keys)
elif isinstance(v, (dict, list)):
stack.append((v, {}))
elif isinstance(current, list):
for item in current:
stack.append((item, {}))
return filtered
@pytest.fixture(scope="function")
def network(example_network) -> Network:
assert len(example_network.routers) is 1
@@ -59,10 +75,10 @@ def test_reset_network(network):
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
state_before, sort_keys=True, indent=2
)
# don't worry if UUIDs change
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
assert a == b
def test_creating_container():

View File

@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
import pytest
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
@@ -42,6 +43,7 @@ def test_ftp_client_store_file(ftp_client):
"dest_folder_name": "downloads",
"dest_file_name": "file.txt",
"file_size": 24,
"health_status": FileSystemItemHealthStatus.GOOD,
},
packet_payload_size=24,
status_code=FTPStatusCode.OK,

View File

@@ -1,5 +1,6 @@
import pytest
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
@@ -41,6 +42,7 @@ def test_ftp_server_store_file(ftp_server):
"dest_folder_name": "downloads",
"dest_file_name": "file.txt",
"file_size": 24,
"health_status": FileSystemItemHealthStatus.GOOD,
},
packet_payload_size=24,
)