diff --git a/.gitignore b/.gitignore index 8be60770..b3d9682a 100644 --- a/.gitignore +++ b/.gitignore @@ -156,4 +156,5 @@ benchmark/output # src/primaite/notebooks/scratch.ipynb src/primaite/notebooks/scratch.py sandbox.py +sandbox/ sandbox.ipynb diff --git a/CHANGELOG.md b/CHANGELOG.md index 541a39d5..5d8706ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,24 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Fixed a bug where ACL rules were not resetting on episode reset. +- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses +- Fixed a bug where deleted files and folders did not reset correctly on episode reset. +- Fixed a bug where service health status was using the actual health state instead of the visible health state +- Fixed a bug where the database file health status was using the incorrect value for negative rewards +- Fixed a bug preventing file actions from reaching their intended file +- Made database patch correctly take 2 timesteps instead of being immediate +- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING +- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon. +- Fix a bug where ACLs were not showing up correctly in the observation space. +- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function. +- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config. +- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config. +- Fixed an issue where the data manipulation attack was triggered at episode start. +- Fixed a bug where FTP STOR stored an additional copy on the client machine's filesystem +- Fixed a bug where the red agent acted to early +- Fixed the order of service health state +- Fixed an issue where starting a node didn't start the services on it @@ -38,6 +56,18 @@ SessionManager. - HTTP Services: `WebBrowser` to simulate a web client and `WebServer` - Fixed an issue where the services were still able to run even though the node the service is installed on is turned off - NTP Services: `NTPClient` and `NTPServer` +- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic. + - **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required. + - **Enhanced Frame Reception**: The `receive_frame` method in `RouterNIC` is tailored to handle frames based on Layer 2 (Ethernet) checks, focusing on MAC address-based routing and broadcast frame acceptance. +- **Subnet-Wide Broadcasting for Services and Applications**: Implemented the ability for services and applications to conduct broadcasts across an entire IPv4 subnet within the network simulation framework. + +### Changed +- Integrated the RouteTable into the Routers frame processing. +- Frames are now dropped when their TTL reaches 0 +- **NIC Functionality Update**: Updated the Network Interface Card (`NIC`) functionality to support Layer 3 (L3) broadcasts. + - **Layer 3 Broadcast Handling**: Enhanced the existing `NIC` classes to correctly process and handle Layer 3 broadcasts. This update allows devices using standard NICs to effectively participate in network activities that involve L3 broadcasting. + - **Improved Frame Reception Logic**: The `receive_frame` method of the `NIC` class has been updated to include additional checks and handling for L3 broadcasts, ensuring proper frame processing in a wider range of network scenarios. + ### Removed - Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol` diff --git a/README.md b/README.md index ec335108..7dfe15bd 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Currently, the PrimAITE wheel can only be installed from GitHub. This may change #### Windows (PowerShell) **Prerequisites:** -* Manual install of Python >= 3.8 < 3.11 +* Manual install of Python >= 3.8 < 3.12 **Install:** @@ -56,7 +56,7 @@ primaite session #### Unix **Prerequisites:** -* Manual install of Python >= 3.8 < 3.11 +* Manual install of Python >= 3.8 < 3.12 ``` bash sudo add-apt-repository ppa:deadsnakes/ppa @@ -82,6 +82,7 @@ primaite session ``` + ### Developer Install from Source To make your own changes to PrimAITE, perform the install from source (developer install) @@ -138,3 +139,7 @@ make html cd docs .\make.bat html ``` + + +## Example notebooks +Check out the example notebooks to learn more about how PrimAITE works and how you can use it to train agents. They are automatically copied to your primaite installation directory when you run `primaite setup`. diff --git a/docs/source/config.rst b/docs/source/config.rst index f4452c7e..23bf6097 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -13,7 +13,25 @@ This section allows selecting which training framework and algorithm to use, and ``io_settings`` --------------- -This section configures how the ``PrimaiteSession`` saves data. +This section configures how PrimAITE saves data during simulation and training. + +**save_final_model**: Only used if training with PrimaiteSession, if true, the policy will be saved after the final training iteration. + +**save_checkpoints**: Only used if training with PrimaiteSession, if true, the policy will be saved periodically during training. + +**checkpoint_interval**: Only used if training with PrimaiteSession and if ``save_checkpoints`` is true. Defines how often to save the policy during training. + +**save_logs**: *currently unused*. + +**save_transactions**: *currently unused*. + +**save_tensorboard_logs**: *currently unused*. + +**save_step_metadata**: Whether to save the RL agents' action, environment state, and other data at every single step. + +**save_pcap_logs**: Whether to save pcap files of all network traffic during the simulation. + +**save_sys_logs**: Whether to save system logs from all nodes during the simulation. ``game`` -------- @@ -56,6 +74,10 @@ Description of configurable items: **agent_settings**: Settings passed to the agent during initialisation. These depend on the agent class. +Reinforcement learning agents use the ``ProxyAgent`` class, they accept these agent settings: + +**flatten_obs**: If true, gymnasium flattening will be performed on the observation space before sending to the agent. Set this to true if your agent does not support nested observation spaces. + ``simulation`` -------------- In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 0fd919fd..43662e8c 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -3.0.0b4dev +3.0.0b6 diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 3c4ac62b..68aa9106 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -14,6 +14,8 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true game: @@ -29,7 +31,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -110,10 +112,8 @@ agents: - service_name: DNSServer - node_hostname: web_server services: - - service_name: DatabaseClient + - service_name: web_server_web_service - node_hostname: database_server - services: - - service_name: DatabaseService folders: - folder_name: database files: @@ -302,63 +302,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite) action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite) action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -507,6 +507,24 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 + reward_function: reward_components: @@ -526,7 +544,7 @@ agents: agent_settings: - # ... + flatten_obs: true diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index c1e2ea81..6aa54487 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -25,7 +25,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index d99d3818..40c40077 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -296,6 +296,16 @@ class NodeFileDeleteAction(NodeFileAbstractAction): super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs) self.verb: str = "delete" + def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id) + file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id) + if node_uuid is None or folder_uuid is None or file_uuid is None: + return ["do_nothing"] + return ["do_nothing"] + # return ["network", "node", node_uuid, "file_system", "delete", "file", folder_uuid, file_uuid] + class NodeFileRepairAction(NodeFileAbstractAction): """Action which repairs a file.""" @@ -443,27 +453,33 @@ class NetworkACLAddRuleAction(AbstractAction): protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2) # subtract 2 to account for UNUSED=0 and ALL=1. - if source_ip_id in [0, 1]: + if source_ip_id == 0: + return ["do_nothing"] # invalid formulation + elif source_ip_id == 1: src_ip = "ALL" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS else: src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 - if source_port_id == 1: + if source_port_id == 0: + return ["do_nothing"] # invalid formulation + elif source_port_id == 1: src_port = "ALL" else: src_port = self.manager.get_port_by_idx(source_port_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 - if dest_ip_id in (0, 1): + if source_ip_id == 0: + return ["do_nothing"] # invalid formulation + elif dest_ip_id == 1: dst_ip = "ALL" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS else: dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 - if dest_port_id == 1: + if dest_port_id == 0: + return ["do_nothing"] # invalid formulation + elif dest_port_id == 1: dst_port = "ALL" else: dst_port = self.manager.get_port_by_idx(dest_port_id - 2) @@ -943,13 +959,22 @@ class ActionManager: :return: The constructed ActionManager. :rtype: ActionManager """ + ip_address_order = cfg["options"].pop("ip_address_order", {}) + ip_address_list = [] + for entry in ip_address_order: + node_ref = entry["node_ref"] + nic_num = entry["nic_num"] + node_obj = game.simulation.network.get_node_by_hostname(node_ref) + ip_address = node_obj.ethernet_port[nic_num].ip_address + ip_address_list.append(ip_address) + obj = cls( game=game, actions=cfg["action_list"], **cfg["options"], protocols=game.options.protocols, ports=game.options.ports, - ip_address_list=None, + ip_address_list=ip_address_list or None, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 791c362d..58b790ec 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -15,7 +15,6 @@ class DataManipulationAgent(AbstractScriptedAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) def _set_next_execution_timestep(self, timestep: int) -> None: @@ -46,3 +45,8 @@ class DataManipulationAgent(AbstractScriptedAgent): self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} + + def reset_agent_for_episode(self) -> None: + """Set the next execution timestep when the episode resets.""" + super().reset_agent_for_episode() + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index fbbe5473..276715f7 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -44,6 +44,8 @@ class AgentSettings(BaseModel): start_settings: Optional[AgentStartSettings] = None "Configuration for when an agent begins performing it's actions" + flatten_obs: bool = True + "Whether to flatten the observation space before passing it to the agent. True by default." @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": @@ -134,6 +136,10 @@ class AbstractAgent(ABC): request = self.action_manager.form_request(action_identifier=action, action_options=options) return request + def reset_agent_for_episode(self) -> None: + """Agent reset logic should go here.""" + pass + class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" @@ -166,6 +172,7 @@ class ProxyAgent(AbstractAgent): action_space: Optional[ActionManager], observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], + agent_settings: Optional[AgentSettings] = None, ) -> None: super().__init__( agent_name=agent_name, @@ -174,6 +181,7 @@ class ProxyAgent(AbstractAgent): reward_function=reward_function, ) self.most_recent_action: ActType + self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index dfb506c5..ea638378 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -1,5 +1,6 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod +from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces @@ -78,7 +79,7 @@ class FileObservation(AbstractObservation): file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: return self.default_observation - return {"health_status": file_state["health_status"]} + return {"health_status": file_state["visible_status"]} @property def space(self) -> spaces.Space: @@ -204,12 +205,15 @@ class LinkObservation(AbstractObservation): bandwidth = link_state["bandwidth"] load = link_state["current_load"] - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 10) + 1 + if load == 0: + utilisation_category = 0 + else: + utilisation_fraction = load / bandwidth + # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% + utilisation_category = int(utilisation_fraction * 9) + 1 # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": utilisation_category}} + return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} @property def space(self) -> spaces.Space: @@ -554,7 +558,7 @@ class NodeObservation(AbstractObservation): folder_configs = config.get("folders", {}) folders = [ FolderObservation.from_config( - config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder + config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder ) for c in folder_configs ] @@ -644,10 +648,13 @@ class AclObservation(AbstractObservation): # TODO: what if the ACL has more rules than num of max rules for obs space obs = {} - for i, rule_state in acl_state.items(): + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] if rule_state is None: - obs[i + 1] = { - "position": i, + obs[i] = { + "position": i - 1, "permission": 0, "source_node_id": 0, "source_port": 0, @@ -656,15 +663,26 @@ class AclObservation(AbstractObservation): "protocol": 0, } else: - obs[i + 1] = { - "position": i, + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, "permission": rule_state["action"], - "source_node_id": self.node_to_id[rule_state["src_ip_address"]], - "source_port": self.port_to_id[rule_state["src_port"]], - "dest_node_id": self.node_to_id[rule_state["dst_ip_address"]], - "dest_port": self.port_to_id[rule_state["dst_port"]], - "protocol": self.protocol_to_id[rule_state["protocol"]], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, } + i += 1 return obs @property diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 66cbbb45..0b292bcb 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -110,10 +110,17 @@ class DatabaseFileIntegrity(AbstractReward): :type state: Dict """ database_file_state = access_from_nested_dict(state, self.location_in_state) + if database_file_state is NOT_PRESENT_IN_STATE: + _LOGGER.info( + f"Could not calculate {self.__class__} reward because " + "simulation state did not contain enough information." + ) + return 0.0 + health_status = database_file_state["health_status"] - if health_status == "corrupted": + if health_status == 2: return -1 - elif health_status == "good": + elif health_status == 1: return 1 else: return 0 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 08677765..368d899a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -13,11 +13,9 @@ from primaite.game.agent.rewards import RewardFunction from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer -from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.router import Router from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot @@ -117,7 +115,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - self.apply_agent_actions() + agent_actions = self.apply_agent_actions() # noqa # Advance timestep self.advance_timestep() @@ -135,12 +133,15 @@ class PrimaiteGame: def apply_agent_actions(self) -> None: """Apply all actions to simulation as requests.""" + agent_actions = {} for agent in self.agents: obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) + agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) + return agent_actions def advance_timestep(self) -> None: """Advance timestep.""" @@ -164,6 +165,7 @@ class PrimaiteGame: self.simulation.reset_component_for_episode(episode=self.episode_counter) for agent in self.agents: agent.reward_function.total_reward = 0.0 + agent.reset_agent_for_episode() def close(self) -> None: """Close the game, this will close the simulation.""" @@ -228,31 +230,7 @@ class PrimaiteGame: operating_state=NodeOperatingState.ON, ) elif n_type == "router": - new_node = Router( - hostname=node_cfg["hostname"], - num_ports=node_cfg.get("num_ports"), - operating_state=NodeOperatingState.ON, - ) - if "ports" in node_cfg: - for port_num, port_cfg in node_cfg["ports"].items(): - new_node.configure_port( - port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"] - ) - # new_node.enable_port(port_num) - if "acl" in node_cfg: - for r_num, r_cfg in node_cfg["acl"].items(): - # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating - # this: 'r_cfg.get('src_port')' - # Port/IPProtocol. TODO Refactor - new_node.acl.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("ip_address"), - dst_ip_address=r_cfg.get("ip_address"), - position=r_num, - ) + new_node = Router.from_config(node_cfg) else: _LOGGER.warning(f"invalid node type {n_type} in config") if "services" in node_cfg: @@ -389,6 +367,7 @@ class PrimaiteGame: action_space=action_space, observation_space=obs_space, reward_function=rew_function, + agent_settings=agent_settings, ) game.agents.append(new_agent) game.rl_agents.append(new_agent) diff --git a/src/primaite/notebooks/_package_data/uc2_attack.png b/src/primaite/notebooks/_package_data/uc2_attack.png new file mode 100644 index 00000000..8b8df5ce Binary files /dev/null and b/src/primaite/notebooks/_package_data/uc2_attack.png differ diff --git a/src/primaite/notebooks/_package_data/uc2_network.png b/src/primaite/notebooks/_package_data/uc2_network.png new file mode 100644 index 00000000..20fa43c9 Binary files /dev/null and b/src/primaite/notebooks/_package_data/uc2_network.png differ diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index a89b29e4..ea006ae9 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -39,6 +39,15 @@ "#### Create a Ray algorithm and pass it our config." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(cfg['agents'][2]['agent_settings'])" + ] + }, { "cell_type": "code", "execution_count": null, @@ -76,6 +85,13 @@ " param_space=config\n", ").fit()\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 3950ef10..679e8226 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -1,30 +1,337 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Manipulation Scenario\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scenario\n", + "\n", + "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n", + "\n", + "[](_package_data/uc2_network.png)\n", + "\n", + "_(click image to enlarge)_\n", + "\n", + "The red agent deletes the contents of the database. When this happens, the web app cannot fetch data and users navigating to the website get a 404 error.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network\n", + "\n", + "- The web server has:\n", + " - a web service that replies to user HTTP requests\n", + " - a database client that fetches data for the web service\n", + "- The database server has:\n", + " - a POSTGRES database service\n", + " - a database file which is accessed by the database service\n", + " - FTP client used for backing up the data to the backup_server\n", + "- The backup server has:\n", + " - a copy of the database file in a known good state\n", + " - FTP server that can send the backed up file back to the database server\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Green agent\n", + "\n", + "The green agent is logged onto client 2. It sometimes uses the web browser on client 2 to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Red agent\n", + "\n", + "The red agent waits a bit then sends a DELETE query to the database from client 1. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n", + "\n", + "[](_package_data/uc2_attack.png)\n", + "\n", + "_(click image to enlarge)_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Blue agent\n", + "\n", + "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking client 1 from reaching the database server. This can be done by removing client 1's network connection or adding ACL rules on the router to stop the packets from arriving." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reinforcement learning details" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scripted agents:\n", + "### Red\n", + "The red agent sits on client 1 and uses an application called DataManipulationBot whose sole purpose is to send a DELETE query to the database.\n", + "The red agent can choose one of two action each timestep:\n", + "1. do nothing\n", + "2. execute the data manipulation application\n", + "The schedule for selecting when to execute the application is controlled by three parameters:\n", + "- start time\n", + "- frequency\n", + "- variance\n", + "Attacks start at a random timestep between (start_time - variance) and (start_time + variance). After each attack, another is attempted after a random delay between (frequency - variance) and (frequency + variance) timesteps.\n", + "\n", + "The data manipulation app itself has an element of randomness because the attack has a probability of success. The default is 0.8 to succeed with the port scan step and 0.8 to succeed with the attack itself.\n", + "Upon a successful attack, the database file becomes corrupted which incurs a negative reward for the RL defender.\n", + "\n", + "The red agent does not use information about the state of the network to decide its action.\n", + "\n", + "### Green\n", + "The green agent sits on client 2 and uses the web browser application to send requests to the web server. The schedule of the green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n", + "\n", + "When the green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Observation Space\n", + "\n", + "The blue agent's observation space is structured as nested dictionary with the following information:\n", + "```\n", + "\n", + "- NODES\n", + " - \n", + " - SERVICES\n", + " - \n", + " - operating_status\n", + " - health_status\n", + " - FOLDERS\n", + " - \n", + " - health_status\n", + " - FILES\n", + " - \n", + " - health_status\n", + " - NICS\n", + " - \n", + " - nic_status\n", + " - operating_status\n", + "- LINKS\n", + " - \n", + " - PROTOCOLS\n", + " - ALL\n", + " - load\n", + "- ACL\n", + " - \n", + " - position\n", + " - permission\n", + " - source_node_id\n", + " - source_port\n", + " - dest_node_id\n", + " - dest_port\n", + " - protocol\n", + "- ICS\n", + "```\n", + "\n", + "### Mappings\n", + "\n", + "The dict keys for `node_id` are in the following order:\n", + "|node_id|node name|\n", + "|--|--|\n", + "|1|domain_controller|\n", + "|2|web_server|\n", + "|3|database_server|\n", + "|4|backup_server|\n", + "|5|security_suite|\n", + "|6|client_1|\n", + "|7|client_2|\n", + "\n", + "Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", + "\n", + "Folder 1 on node 3 corresponds to the database folder. File 1 in that folder corresponds to the database storage file. Other files and folders are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n", + "\n", + "The dict keys for `link_id` are in the following order:\n", + "|link_id|endpoint_a|endpoint_b|\n", + "|--|--|--|\n", + "|1|router_1|switch_1|\n", + "|1|router_1|switch_2|\n", + "|1|switch_1|domain_controller|\n", + "|1|switch_1|web_server|\n", + "|1|switch_1|database_server|\n", + "|1|switch_1|backup_server|\n", + "|1|switch_1|security_suite|\n", + "|1|switch_2|client_1|\n", + "|1|switch_2|client_2|\n", + "|1|switch_2|security_suite|\n", + "\n", + "The ACL rules in the observation space appear in the same order that they do in the actual ACL. Though, only the first 10 rules are shown, there are default rules lower down that cannot be changed by the agent. The extra rules just allow the network to function normally, by allowing pings, ARP traffic, etc.\n", + "\n", + "Most nodes have only 1 nic, so the observation for those is placed at NIC index 1 in the observation space. Only the security suite has 2 NICs, the second NIC in the observation space is the one that connects the security suite with swtich_2.\n", + "\n", + "The meaning of the services' operating_state is:\n", + "|operating_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|RUNNING|\n", + "|2|STOPPED|\n", + "|3|PAUSED|\n", + "|4|DISABLED|\n", + "|5|INSTALLING|\n", + "|6|RESTARTING|\n", + "\n", + "The meaning of the services' health_state is:\n", + "|health_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|GOOD|\n", + "|2|PATCHING|\n", + "|3|COMPROMISED|\n", + "|4|OVERWHELMED|\n", + "\n", + "The meaning of the files' and folders' health_state is:\n", + "|health_state|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|GOOD|\n", + "|2|COMPROMISED|\n", + "|3|CORRUPT|\n", + "|4|RESTORING|\n", + "|5|REPAIRING|\n", + "\n", + "The meaning of the NICs' operating_status is:\n", + "|operating_status|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ENABLED|\n", + "|2|DISABLED|\n", + "\n", + "Link load has the following meaning:\n", + "|load|percent utilisation|\n", + "|--|--|\n", + "|0|exactly 0%|\n", + "|1|0-11%|\n", + "|2|11-22%|\n", + "|3|22-33%|\n", + "|4|33-44%|\n", + "|5|44-55%|\n", + "|6|55-66%|\n", + "|7|66-77%|\n", + "|8|77-88%|\n", + "|9|88-99%|\n", + "|10|exactly 100%|\n", + "\n", + "ACL permission has the following meaning:\n", + "|permission|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ALLOW|\n", + "|2|DENY|\n", + "\n", + "ACL source / destination node ids actually correspond to IP addresses (since ACLs work with IP addresses)\n", + "|source / dest node id|ip_address|label|\n", + "|--|--|--|\n", + "|0| | UNUSED|\n", + "|1| |ALL addresses|\n", + "|2| 192.168.1.10 | domain_controller|\n", + "|3| 192.168.1.12 | web_server \n", + "|4| 192.168.1.14 | database_server|\n", + "|5| 192.168.1.16 | backup_server|\n", + "|6| 192.168.1.110 | security_suite (eth-1)|\n", + "|7| 192.168.10.21 | client_1|\n", + "|8| 192.168.10.22 | client_2|\n", + "|9| 192.168.10.110| security_suite (eth-2)|\n", + "\n", + "ACL source / destination port ids have the following encoding:\n", + "|port id|port number| port use |\n", + "|--|--|--|\n", + "|0||UNUSED|\n", + "|1||ALL|\n", + "|2|219|ARP|\n", + "|3|53|DNS|\n", + "|4|80|HTTP|\n", + "|5|5432|POSTGRES_SERVER|\n", + "\n", + "ACL protocol ids have the following encoding:\n", + "|protocol id|label|\n", + "|--|--|\n", + "|0|UNUSED|\n", + "|1|ALL|\n", + "|2|ICMP|\n", + "|3|TCP|\n", + "|4|UDP|\n", + "\n", + "protocol" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Action Space\n", + "\n", + "The blue agent chooses from a list of 54 pre-defined actions. The full list is defined in the `action_map` in the config. The most important ones are explained here:\n", + "\n", + "- `0`: Do nothing\n", + "- `1`: Scan the web service - this refreshes the health status in the observation space\n", + "- `9`: Scan the database file - this refreshes the health status of the database file\n", + "- `13`: Patch the database service - This triggers the database to restore data from the backup server\n", + "- `19`: Shut down client 1\n", + "- `22`: Block outgoing traffic from client 1\n", + "- `26`: Block TCP traffic from client 1 to the database node\n", + "- `28-37`: Remove ACL rules 1-10\n", + "- `42`: Disconnect client 1 from the network\n", + "\n", + "The other actions will either have no effect or will negatively impact the network, so the blue agent should avoid taking other actions, and learn about these actions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reward Function\n", + "\n", + "The blue agent's reward is calculated using two measures:\n", + "1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n", + "2. Whether the green agent's most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n", + "These two components are averaged to get the final reward.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demonstration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, load the required modules" + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" - ] - } - ], + "outputs": [], "source": [ - "from primaite.session.session import PrimaiteSession\n", - "from primaite.game.game import PrimaiteGame\n", - "from primaite.config.load import example_config_path\n", - "\n", - "from primaite.simulator.system.services.database.database_service import DatabaseService\n", - "\n", - "import yaml" + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -36,61 +343,182 @@ "name": "stderr", "output_type": "stream", "text": [ - "2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n", - "2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "installing DNSServer on node domain_controller\n", - "installing DatabaseClient on node web_server\n", - "installing WebServer on node web_server\n", - "installing DatabaseService on node database_server\n", - "installing FTPClient on node database_server\n", - "installing FTPServer on node backup_server\n", - "installing DNSClient on node client_1\n", - "installing DNSClient on node client_2\n" + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2024-01-25 14:43:32,056\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2024-01-25 14:43:35,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] } ], "source": [ + "# Imports\n", + "from primaite.config.load import example_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from pprint import pprint\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instantiate the environment. We also disable the agent observation flattening.\n", "\n", - "with open(example_config_path(),'r') as cfgfile:\n", - " cfg = yaml.safe_load(cfgfile)\n", - "game = PrimaiteGame.from_config(cfg)\n", - "net = game.simulation.network\n", - "database_server = net.get_node_by_hostname('database_server')\n", - "web_server = net.get_node_by_hostname('web_server')\n", - "client_1 = net.get_node_by_hostname('client_1')\n", - "\n", - "db_service = database_server.software_manager.software[\"DatabaseService\"]\n", - "db_client = web_server.software_manager.software[\"DatabaseClient\"]\n", - "# db_client.run()\n", - "db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n", - "db_manipulation_bot.port_scan_p_of_success=1.0\n", - "db_manipulation_bot.data_manipulation_p_of_success=1.0\n" + "This cell will print the observation when the network is healthy. You should be able to verify Node file and service statuses against the description above." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting environment, episode 0, avg. reward: 0.0\n", + "env created successfully\n", + "{'ACL': {1: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 0,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 2: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 1,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 3: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 2,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 4: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 3,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 5: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 4,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 6: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 5,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 7: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 6,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 8: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 7,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 9: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 8,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 10: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 9,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0}},\n", + " 'ICS': 0,\n", + " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", + " 2: {'PROTOCOLS': {'ALL': 1}},\n", + " 3: {'PROTOCOLS': {'ALL': 1}},\n", + " 4: {'PROTOCOLS': {'ALL': 1}},\n", + " 5: {'PROTOCOLS': {'ALL': 1}},\n", + " 6: {'PROTOCOLS': {'ALL': 1}},\n", + " 7: {'PROTOCOLS': {'ALL': 1}},\n", + " 8: {'PROTOCOLS': {'ALL': 1}},\n", + " 9: {'PROTOCOLS': {'ALL': 1}},\n", + " 10: {'PROTOCOLS': {'ALL': 1}}},\n", + " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", + " 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}}\n" + ] + } + ], "source": [ - "db_client.run()" + "# create the env\n", + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " # set success probability to 1.0 to avoid rerunning cells.\n", + " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", + " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", + "game = PrimaiteGame.from_config(cfg)\n", + "env = PrimaiteGymEnv(game = game)\n", + "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", + "env.agent.flatten_obs = False\n", + "obs, info = env.reset()\n", + "print('env created successfully')\n", + "pprint(obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will go from 1.0 to 0.0, and to -1.0 when the green agent tries to access the webpage." ] }, { @@ -99,18 +527,55 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 1, Red action: DONOTHING, Blue reward:0.5\n", + "step: 2, Red action: DONOTHING, Blue reward:0.5\n", + "step: 3, Red action: DONOTHING, Blue reward:0.5\n", + "step: 4, Red action: DONOTHING, Blue reward:0.5\n", + "step: 5, Red action: DONOTHING, Blue reward:1.0\n", + "step: 6, Red action: DONOTHING, Blue reward:1.0\n", + "step: 7, Red action: DONOTHING, Blue reward:1.0\n", + "step: 8, Red action: DONOTHING, Blue reward:1.0\n", + "step: 9, Red action: DONOTHING, Blue reward:1.0\n", + "step: 10, Red action: DONOTHING, Blue reward:1.0\n", + "step: 11, Red action: DONOTHING, Blue reward:1.0\n", + "step: 12, Red action: DONOTHING, Blue reward:1.0\n", + "step: 13, Red action: DONOTHING, Blue reward:1.0\n", + "step: 14, Red action: DONOTHING, Blue reward:1.0\n", + "step: 15, Red action: DONOTHING, Blue reward:1.0\n", + "step: 16, Red action: DONOTHING, Blue reward:1.0\n", + "step: 17, Red action: DONOTHING, Blue reward:1.0\n", + "step: 18, Red action: DONOTHING, Blue reward:1.0\n", + "step: 19, Red action: DONOTHING, Blue reward:1.0\n", + "step: 20, Red action: DONOTHING, Blue reward:1.0\n", + "step: 21, Red action: DONOTHING, Blue reward:1.0\n", + "step: 22, Red action: NODE_APPLICATION_EXECUTE, Blue reward:0.0\n", + "step: 23, Red action: DONOTHING, Blue reward:0.0\n", + "step: 24, Red action: DONOTHING, Blue reward:0.0\n", + "step: 25, Red action: DONOTHING, Blue reward:0.0\n", + "step: 26, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 27, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 28, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 29, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 30, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 31, Red action: DONOTHING, Blue reward:-1.0\n", + "step: 32, Red action: DONOTHING, Blue reward:-1.0\n" + ] } ], "source": [ - "db_service.backup_database()" + "for step in range(32):\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the reward is -1, let's have a look at blue agent's observation." ] }, { @@ -119,27 +584,110 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}}, 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "pprint(obs['NODES'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The true statuses of the database file and webapp are not updated. The blue agent needs to perform a scan to see that they have degraded." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 3, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 2}}, 'health_status': 1}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NICS': {1: {'nic_status': 1}, 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] + } + ], "source": [ - "db_manipulation_bot.run()" + "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", + "obs, reward, terminated, truncated, info = env.step(1) # scan webapp service\n", + "pprint(obs['NODES'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now service 1 on node 2 has `health_status = 3`, indicating that the webapp is compromised.\n", + "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The blue agent can now patch the database to restore the file to a good health status." ] }, { @@ -148,18 +696,33 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 35\n", + "Red action: DONOTHING\n", + "Green action: NODE_APPLICATION_EXECUTE\n", + "Blue reward:-1.0\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", + "print(f\"step: {env.game.step_counter}\")\n", + "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The patching takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n", + "\n", + "The reward will be 0 as soon as the file finishes restoring. Then, the reward will increase to 1 when the green agent makes a request. (Because the webapp access part of the reward does not update until a successful request is made.)\n", + "\n", + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -168,18 +731,29 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 36\n", + "Red action: DONOTHING\n", + "Green action: NODE_APPLICATION_EXECUTE\n", + "Blue reward:0.0\n" + ] } ], "source": [ - "db_service.restore_backup()" + "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", + "print(f\"step: {env.game.step_counter}\")\n", + "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Blue reward:{reward}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The blue agent can prevent attacks by implementing an ACL rule to stop client_1 from sending POSTGRES traffic to the database. (Let's also patch the database file to get the reward back up.)" ] }, { @@ -188,90 +762,157 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 37, Red action: DONOTHING, Blue reward:0.0\n", + "step: 38, Red action: DONOTHING, Blue reward:0.0\n", + "step: 39, Red action: DONOTHING, Blue reward:1.0\n", + "step: 40, Red action: DONOTHING, Blue reward:1.0\n", + "step: 41, Red action: DONOTHING, Blue reward:1.0\n", + "step: 42, Red action: DONOTHING, Blue reward:1.0\n", + "step: 43, Red action: DONOTHING, Blue reward:1.0\n", + "step: 44, Red action: DONOTHING, Blue reward:1.0\n", + "step: 45, Red action: DONOTHING, Blue reward:1.0\n", + "step: 46, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.0\n", + "step: 47, Red action: DONOTHING, Blue reward:1.0\n", + "step: 48, Red action: DONOTHING, Blue reward:1.0\n", + "step: 49, Red action: DONOTHING, Blue reward:1.0\n", + "step: 50, Red action: DONOTHING, Blue reward:1.0\n", + "step: 51, Red action: DONOTHING, Blue reward:1.0\n", + "step: 52, Red action: DONOTHING, Blue reward:1.0\n", + "step: 53, Red action: DONOTHING, Blue reward:1.0\n", + "step: 54, Red action: DONOTHING, Blue reward:1.0\n", + "step: 55, Red action: DONOTHING, Blue reward:1.0\n", + "step: 56, Red action: DONOTHING, Blue reward:1.0\n", + "step: 57, Red action: DONOTHING, Blue reward:1.0\n", + "step: 58, Red action: DONOTHING, Blue reward:1.0\n", + "step: 59, Red action: DONOTHING, Blue reward:1.0\n", + "step: 60, Red action: DONOTHING, Blue reward:1.0\n", + "step: 61, Red action: DONOTHING, Blue reward:1.0\n", + "step: 62, Red action: DONOTHING, Blue reward:1.0\n", + "step: 63, Red action: DONOTHING, Blue reward:1.0\n", + "step: 64, Red action: DONOTHING, Blue reward:1.0\n", + "step: 65, Red action: DONOTHING, Blue reward:1.0\n", + "step: 66, Red action: DONOTHING, Blue reward:1.0\n", + "step: 67, Red action: DONOTHING, Blue reward:1.0\n", + "step: 68, Red action: DONOTHING, Blue reward:1.0\n" + ] } ], "source": [ - "db_client.query(\"SELECT\")" + "env.step(13) # Patch the database\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "\n", + "env.step(26) # Block client 1\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "\n", + "for step in range(30):\n", + " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "db_manipulation_bot.run()" + "Now, even though the red agent executes an attack, the reward stays at 1.0" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "client_1.ping(database_server.ethernet_port[1].ip_address)" + "Let's also have a look at the ACL observation to verify our new ACL rule at position 5." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, - "outputs": [], - "source": [ - "from pydantic import validate_call, BaseModel" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "class A(BaseModel):\n", - " x:int\n", - "\n", - " @validate_call\n", - " def increase_x(self, by:int) -> None:\n", - " self.x += 1" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "my_a = A(x=3)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, "outputs": [ { - "ename": "ValidationError", - "evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n", - "File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n", - "\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float" - ] + "data": { + "text/plain": [ + "{1: {'position': 0,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 2: {'position': 1,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 3: {'position': 2,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 4: {'position': 3,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 5: {'position': 4,\n", + " 'permission': 2,\n", + " 'source_node_id': 7,\n", + " 'source_port': 1,\n", + " 'dest_node_id': 4,\n", + " 'dest_port': 1,\n", + " 'protocol': 3},\n", + " 6: {'position': 5,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 7: {'position': 6,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 8: {'position': 7,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 9: {'position': 8,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 10: {'position': 9,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "my_a.increase_x(3.2)" + "obs['ACL']" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index ca71a0c0..a3831bc1 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -29,7 +29,7 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() @@ -39,7 +39,7 @@ class PrimaiteGymEnv(gymnasium.Env): reward = self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() - info = {} + info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) return next_obs, reward, terminated, truncated, info @@ -81,13 +81,19 @@ class PrimaiteGymEnv(gymnasium.Env): @property def observation_space(self) -> gymnasium.Space: """Return the observation space of the environment.""" - return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + if self.agent.flatten_obs: + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + else: + return self.agent.observation_manager.space def _get_obs(self) -> ObsType: """Return the current observation.""" - unflat_space = self.agent.observation_manager.space - unflat_obs = self.agent.observation_manager.current_observation - return gymnasium.spaces.flatten(unflat_space, unflat_obs) + if not self.agent.flatten_obs: + return self.agent.observation_manager.current_observation + else: + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) class PrimaiteRayEnv(gymnasium.Env): @@ -166,7 +172,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -180,7 +186,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {} + infos = {"agent_actions": agent_actions} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 0d80a385..b4b740e9 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -24,9 +24,13 @@ class SessionIOSettings(BaseModel): save_transactions: bool = True """Whether to save transactions, If true, the session path will have a transactions folder.""" save_tensorboard_logs: bool = False - """Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder.""" + """Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder.""" save_step_metadata: bool = False """Whether to save the RL agents' action, environment state, and other data at every single step.""" + save_pcap_logs: bool = False + """Whether to save PCAP logs.""" + save_sys_logs: bool = False + """Whether to save system logs.""" class SessionIO: @@ -39,9 +43,10 @@ class SessionIO: def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: self.settings: SessionIOSettings = settings self.session_path: Path = self.generate_session_path() - # set global SIM_OUTPUT path SIM_OUTPUT.path = self.session_path / "simulation_output" + SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs + SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's # possible refactor needed diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index ef462d83..5c663cfd 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -54,7 +54,7 @@ class PrimaiteSession: self.policy: PolicyABC """The reinforcement learning policy.""" - self.io_manager = SessionIO() + self.io_manager: Optional["SessionIO"] = None """IO manager for the session.""" self.game: PrimaiteGame = game @@ -101,9 +101,9 @@ class PrimaiteSession: # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"game": game}) + sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"game": game}) + sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) elif sess.training_options.rl_framework == "SB3": sess.env = PrimaiteGymEnv(game=game) diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index a4ee4c4d..bcf89b6a 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -44,3 +44,12 @@ def run(overwrite_existing: bool = True) -> None: print(dst_fp) shutil.copy2(src_fp, dst_fp) _LOGGER.info(f"Reset example notebook: {dst_fp}") + + for src_fp in primaite_root.glob("notebooks/_package_data/*"): + dst_fp = example_notebooks_user_dir / "_package_data" / src_fp.name + if should_copy_file(src_fp, dst_fp, overwrite_existing): + if not Path.exists(example_notebooks_user_dir / "_package_data/"): + Path.mkdir(example_notebooks_user_dir / "_package_data/") + print(dst_fp) + shutil.copy2(src_fp, dst_fp) + _LOGGER.info(f"Copied notebook resource to: {dst_fp}") diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 19c86e28..aebd77cf 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -7,11 +7,13 @@ from primaite import _PRIMAITE_ROOT __all__ = ["SIM_OUTPUT"] -class __SimOutput: +class _SimOutput: def __init__(self): self._path: Path = ( _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) + self.save_pcap_logs: bool = False + self.save_sys_logs: bool = False @property def path(self) -> Path: @@ -23,4 +25,4 @@ class __SimOutput: self._path.mkdir(exist_ok=True, parents=True) -SIM_OUTPUT = __SimOutput() +SIM_OUTPUT = _SimOutput() diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index f5e734cf..48ea587d 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -23,7 +23,6 @@ class FileSystem(SimComponent): "List containing all the folders in the file system." deleted_folders: Dict[str, Folder] = {} "List containing all the folders that have been deleted." - _folders_by_name: Dict[str, Folder] = {} sys_log: SysLog "Instance of SysLog used to create system logs." sim_root: Path @@ -56,7 +55,6 @@ class FileSystem(SimComponent): folder = self.deleted_folders[uuid] self.deleted_folders.pop(uuid) self.folders[uuid] = folder - self._folders_by_name[folder.name] = folder # Clear any other deleted folders that aren't original (have been created by agent) self.deleted_folders.clear() @@ -67,7 +65,6 @@ class FileSystem(SimComponent): if uuid not in original_folder_uuids: folder = self.folders[uuid] self.folders.pop(uuid) - self._folders_by_name.pop(folder.name) # Now reset all remaining folders for folder in self.folders.values(): @@ -173,7 +170,6 @@ class FileSystem(SimComponent): folder = Folder(name=folder_name, sys_log=self.sys_log) self.folders[folder.uuid] = folder - self._folders_by_name[folder.name] = folder self._folder_request_manager.add_request( name=folder.name, request_type=RequestType(func=folder._request_manager) ) @@ -188,14 +184,13 @@ class FileSystem(SimComponent): if folder_name == "root": self.sys_log.warning("Cannot delete the root folder.") return - folder = self._folders_by_name.get(folder_name) + folder = self.get_folder(folder_name) if folder: # set folder to deleted state folder.delete() # remove from folder list self.folders.pop(folder.uuid) - self._folders_by_name.pop(folder.name) # add to deleted list folder.remove_all_files() @@ -221,7 +216,10 @@ class FileSystem(SimComponent): :param folder_name: The folder name. :return: The matching Folder. """ - return self._folders_by_name.get(folder_name) + for folder in self.folders.values(): + if folder.name == folder_name: + return folder + return None def get_folder_by_id(self, folder_uuid: str, include_deleted: bool = False) -> Optional[Folder]: """ @@ -261,13 +259,13 @@ class FileSystem(SimComponent): """ if folder_name: # check if file with name already exists - folder = self._folders_by_name.get(folder_name) + folder = self.get_folder(folder_name) # If not then create it if not folder: folder = self.create_folder(folder_name) else: # Use root folder if folder_name not supplied - folder = self._folders_by_name["root"] + folder = self.get_folder("root") # Create the file and add it to the folder file = File( @@ -474,7 +472,6 @@ class FileSystem(SimComponent): folder.restore() self.folders[folder.uuid] = folder - self._folders_by_name[folder.name] = folder if folder.deleted: self.deleted_folders.pop(folder.uuid) diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index 86cd1ee7..c3e1426b 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -87,7 +87,7 @@ class FileSystemItemABC(SimComponent): def set_original_state(self): """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"} + vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} self._original_state = self.model_dump(include=vals_to_keep) def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index d4e72f63..dae32cd5 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -17,8 +17,6 @@ class Folder(FileSystemItemABC): files: Dict[str, File] = {} "Files stored in the folder." - _files_by_name: Dict[str, File] = {} - "Files by their name as .." deleted_files: Dict[str, File] = {} "Files that have been deleted." @@ -78,7 +76,6 @@ class Folder(FileSystemItemABC): file = self.deleted_files[uuid] self.deleted_files.pop(uuid) self.files[uuid] = file - self._files_by_name[file.name] = file # Clear any other deleted files that aren't original (have been created by agent) self.deleted_files.clear() @@ -89,7 +86,6 @@ class Folder(FileSystemItemABC): if uuid not in original_file_uuids: file = self.files[uuid] self.files.pop(uuid) - self._files_by_name.pop(file.name) # Now reset all remaining files for file in self.files.values(): @@ -105,7 +101,7 @@ class Folder(FileSystemItemABC): self._file_request_manager = RequestManager() rm.add_request( name="file", - request_type=RequestType(func=lambda request, context: self._file_request_manager), + request_type=RequestType(func=self._file_request_manager), ) return rm @@ -219,7 +215,10 @@ class Folder(FileSystemItemABC): :return: The matching File. """ # TODO: Increment read count? - return self._files_by_name.get(file_name) + for file in self.files.values(): + if file.name == file_name: + return file + return None def get_file_by_id(self, file_uuid: str, include_deleted: Optional[bool] = False) -> File: """ @@ -250,15 +249,14 @@ class Folder(FileSystemItemABC): raise Exception(f"Invalid file: {file}") # check if file with id or name already exists in folder - if (force is not True) and file.name in self._files_by_name: + if self.get_file(file.name) is not None and not force: raise Exception(f"File with name {file.name} already exists in folder") - if (force is not True) and file.uuid in self.files: + if (file.uuid in self.files) and not force: raise Exception(f"File with uuid {file.uuid} already exists in folder") # add to list self.files[file.uuid] = file - self._files_by_name[file.name] = file self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager)) file.folder = self @@ -275,11 +273,9 @@ class Folder(FileSystemItemABC): if self.files.get(file.uuid): self.files.pop(file.uuid) - self._files_by_name.pop(file.name) self.deleted_files[file.uuid] = file file.delete() self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})") - self._file_request_manager.remove_request(file.uuid) else: _LOGGER.debug(f"File with UUID {file.uuid} was not found.") @@ -300,7 +296,6 @@ class Folder(FileSystemItemABC): self.deleted_files[file_id] = file self.files = {} - self._files_by_name = {} def restore_file(self, file_uuid: str): """ @@ -316,7 +311,6 @@ class Folder(FileSystemItemABC): file.restore() self.files[file.uuid] = file - self._files_by_name[file.name] = file if file.deleted: self.deleted_files.pop(file_uuid) diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py new file mode 100644 index 00000000..48313a1f --- /dev/null +++ b/src/primaite/simulator/network/creation.py @@ -0,0 +1,148 @@ +from ipaddress import IPv4Address +from typing import Optional + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port + + +def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int: + """ + Calculate the minimum number of network switches required to connect a given number of nodes. + + Each switch is assumed to have one port reserved for connecting to a router, reducing the effective + number of ports available for PCs. The function calculates the total number of switches needed + to accommodate all nodes under this constraint. + + :param num_nodes: The total number of nodes that need to be connected in the network. + :param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24. + + :return: The minimum number of switches required to connect all PCs. + + Example: + >>> num_of_switches_required(5) + 1 + >>> num_of_switches_required(24,24) + 2 + >>> num_of_switches_required(48,24) + 3 + >>> num_of_switches_required(25,10) + 3 + """ + # Reduce the effective number of switch ports by 1 to leave space for the router + effective_switch_ports = max_switch_ports - 1 + + # Calculate the number of fully utilised switches and any additional switch for remaining PCs + full_switches = num_nodes // effective_switch_ports + extra_pcs = num_nodes % effective_switch_ports + + # Return the total number of switches required + return full_switches + (1 if extra_pcs > 0 else 0) + + +def create_office_lan( + lan_name: str, + subnet_base: int, + pcs_ip_block_start: int, + num_pcs: int, + network: Optional[Network] = None, + include_router: bool = True, +) -> Network: + """ + Creates a 2-Tier or 3-Tier office local area network (LAN). + + The LAN is configured with a specified number of personal computers (PCs), optionally including a router, + and multiple edge switches to connect them. A core switch is added only if more than one edge switch is required. + The network topology involves edge switches connected either directly to the router in a 2-Tier setup or + to a core switch in a 3-Tier setup. If a router is included, it is connected to the core switch (if present) + and configured with basic access control list (ACL) rules. PCs are distributed across the edge switches. + + + :param str lan_name: The name to be assigned to the LAN. + :param int subnet_base: The subnet base number to be used in the IP addresses. + :param int pcs_ip_block_start: The starting block for assigning IP addresses to PCs. + :param int num_pcs: The number of PCs to be added to the LAN. + :param Optional[Network] network: The network to which the LAN components will be added. If None, a new network is + created. + :param bool include_router: Flag to determine if a router should be included in the LAN. Defaults to True. + :return: The network object with the LAN components added. + :raises ValueError: If pcs_ip_block_start is less than or equal to the number of required switches. + """ + # Initialise the network if not provided + if not network: + network = Network() + + # Calculate the required number of switches + num_of_switches = num_of_switches_required(num_nodes=num_pcs) + effective_switch_ports = 23 # One port less for router connection + if pcs_ip_block_start <= num_of_switches: + raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}") + + # Create a core switch if more than one edge switch is needed + if num_of_switches > 1: + core_switch = Switch(hostname=f"switch_core_{lan_name}", start_up_duration=0) + core_switch.power_on() + network.add_node(core_switch) + core_switch_port = 1 + + # Initialise the default gateway to None + default_gateway = None + + # Optionally include a router in the LAN + if include_router: + default_gateway = IPv4Address(f"192.168.{subnet_base}.1") + router = Router(hostname=f"router_{lan_name}", start_up_duration=0) + router.power_on() + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + network.add_node(router) + router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") + router.enable_port(1) + + # Initialise the first edge switch and connect to the router or core switch + switch_port = 0 + switch_n = 1 + switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0) + switch.power_on() + network.add_node(switch) + if num_of_switches > 1: + network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24]) + else: + network.connect(router.ethernet_ports[1], switch.switch_ports[24]) + + # Add PCs to the LAN and connect them to switches + for i in range(1, num_pcs + 1): + # Add a new edge switch if the current one is full + if switch_port == effective_switch_ports: + switch_n += 1 + switch_port = 0 + switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0) + switch.power_on() + network.add_node(switch) + # Connect the new switch to the router or core switch + if num_of_switches > 1: + core_switch_port += 1 + network.connect(core_switch.switch_ports[core_switch_port], switch.switch_ports[24]) + else: + network.connect(router.ethernet_ports[1], switch.switch_ports[24]) + + # Create and add a PC to the network + pc = Computer( + hostname=f"pc_{i}_{lan_name}", + ip_address=f"192.168.{subnet_base}.{i+pcs_ip_block_start-1}", + subnet_mask="255.255.255.0", + default_gateway=default_gateway, + start_up_duration=0, + ) + pc.power_on() + network.add_node(pc) + + # Connect the PC to the switch + switch_port += 1 + network.connect(switch.switch_ports[switch_port], pc.ethernet_port[1]) + switch.switch_ports[switch_port].enable() + + return network diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index bbcdfe37..8e4c1d76 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -4,7 +4,7 @@ import re import secrets from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable @@ -274,18 +274,40 @@ class NIC(SimComponent): def receive_frame(self, frame: Frame) -> bool: """ - Receive a network frame from the connected link if the NIC is enabled. + Receive a network frame from the connected link, processing it if the NIC is enabled. - The Frame is passed to the Node. + This method decrements the Time To Live (TTL) of the frame, captures it using PCAP (Packet Capture), and checks + if the frame is either a broadcast or destined for this NIC. If the frame is acceptable, it is passed to the + connected node. The method also handles the discarding of frames with TTL expired and logs this event. - :param frame: The network frame being received. + The frame's reception is based on various conditions: + - If the NIC is disabled, the frame is not processed. + - If the TTL of the frame reaches zero after decrement, it is discarded and logged. + - If the frame is a broadcast or its destination MAC/IP address matches this NIC's, it is accepted. + - All other frames are dropped and logged or printed to the console. + + :param frame: The network frame being received. This should be an instance of the Frame class. + :return: Returns True if the frame is processed and passed to the node, False otherwise. """ if self.enabled: frame.decrement_ttl() + if frame.ip and frame.ip.ttl < 1: + self._connected_node.sys_log.info("Frame discarded as TTL limit reached") + return False frame.set_received_timestamp() self.pcap.capture(frame) # If this destination or is broadcast - if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff": + accept_frame = False + + # Check if it's a broadcast: + if frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff": + if frame.ip.dst_ip_address in {self.ip_address, self.ip_network.broadcast_address}: + accept_frame = True + else: + if frame.ethernet.dst_mac_addr == self.mac_address: + accept_frame = True + + if accept_frame: self._connected_node.receive_frame(frame=frame, from_nic=self) return True return False @@ -436,6 +458,9 @@ class SwitchPort(SimComponent): """ if self.enabled: frame.decrement_ttl() + if frame.ip and frame.ip.ttl < 1: + self._connected_node.sys_log.info("Frame discarded as TTL limit reached") + return False self.pcap.capture(frame) connected_node: Node = self._connected_node connected_node.forward_frame(frame=frame, incoming_port=self) @@ -671,17 +696,30 @@ class ARPCache: """Clear the entire ARP cache, removing all stored entries.""" self.arp.clear() - def send_arp_request(self, target_ip_address: Union[IPv4Address, str]): + def send_arp_request( + self, target_ip_address: Union[IPv4Address, str], ignore_networks: Optional[List[IPv4Address]] = None + ): """ Perform a standard ARP request for a given target IP address. Broadcasts the request through all enabled NICs to determine the MAC address corresponding to the target IP - address. + address. This method can be configured to ignore specific networks when sending out ARP requests, + which is useful in environments where certain addresses should not be queried. :param target_ip_address: The target IP address to send an ARP request for. + :param ignore_networks: An optional list of IPv4 addresses representing networks to be excluded from the ARP + request broadcast. Each address in this list indicates a network which will not be queried during the ARP + request process. This is particularly useful in complex network environments where traffic should be + minimized or controlled to specific subnets. It is mainly used by the router to prevent ARP requests being + sent back to their source. """ for nic in self.nics.values(): - if nic.enabled: + use_nic = True + if ignore_networks: + for ipv4 in ignore_networks: + if ipv4 in nic.ip_network: + use_nic = False + if nic.enabled and use_nic: self.sys_log.info(f"Sending ARP request from NIC {nic} for ip {target_ip_address}") tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) @@ -806,7 +844,6 @@ class ICMP: self.arp.send_arp_request(frame.ip.src_ip_address) self.process_icmp(frame=frame, from_nic=from_nic, is_reattempt=True) return - tcp_header = TCPHeader(src_port=Port.ARP, dst_port=Port.ARP) # Network Layer ip_packet = IPPacket( @@ -821,9 +858,7 @@ class ICMP: sequence=frame.icmp.sequence + 1, ) payload = secrets.token_urlsafe(int(32 / 1.3)) # Standard ICMP 32 bytes size - frame = Frame( - ethernet=ethernet_header, ip=ip_packet, tcp=tcp_header, icmp=icmp_reply_packet, payload=payload - ) + frame = Frame(ethernet=ethernet_header, ip=ip_packet, icmp=icmp_reply_packet, payload=payload) self.sys_log.info(f"Sending echo reply to {frame.ip.dst_ip_address}") src_nic.send_frame(frame) @@ -1275,8 +1310,8 @@ class Node(SimComponent): self.start_up_countdown = self.start_up_duration if self.start_up_duration <= 0: - self._start_up_actions() self.operating_state = NodeOperatingState.ON + self._start_up_actions() self.sys_log.info("Turned on") for nic in self.nics.values(): if nic._connected_link: @@ -1450,7 +1485,7 @@ class Node(SimComponent): service.parent = self service.install() # Perform any additional setup, such as creating files for this service on the node. self.sys_log.info(f"Installed service {service.name}") - _LOGGER.info(f"Added service {service.name} to node {self.hostname}") + _LOGGER.debug(f"Added service {service.name} to node {self.hostname}") self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager)) def uninstall_service(self, service: Service) -> None: @@ -1485,7 +1520,7 @@ class Node(SimComponent): self.applications[application.uuid] = application application.parent = self self.sys_log.info(f"Installed application {application.name}") - _LOGGER.info(f"Added application {application.name} to node {self.hostname}") + _LOGGER.debug(f"Added application {application.name} to node {self.hostname}") self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager)) def uninstall_application(self, application: Application) -> None: diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 0e6bc946..9a34be0b 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -9,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader @@ -18,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog class ACLAction(Enum): """Enum for defining the ACL action types.""" - DENY = 0 PERMIT = 1 + DENY = 2 class ACLRule(SimComponent): @@ -65,11 +66,11 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.value if self.protocol else None + state["protocol"] = self.protocol.name if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None - state["src_port"] = self.src_port.value if self.src_port else None + state["src_port"] = self.src_port.name if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None - state["dst_port"] = self.dst_port.value if self.dst_port else None + state["dst_port"] = self.dst_port.name if self.dst_port else None return state @@ -89,6 +90,8 @@ class AccessControlList(SimComponent): implicit_rule: ACLRule max_acl_rules: int = 25 _acl: List[Optional[ACLRule]] = [None] * 24 + _default_config: Dict[int, dict] = {} + """Config dict describing how the ACL list should look at episode start""" def __init__(self, **kwargs) -> None: if not kwargs.get("implicit_action"): @@ -106,10 +109,40 @@ class AccessControlList(SimComponent): vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + for i, rule in enumerate(self._acl): + if not rule: + continue + self._default_config[i] = {"action": rule.action.name} + if rule.src_ip_address: + self._default_config[i]["src_ip"] = str(rule.src_ip_address) + if rule.dst_ip_address: + self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) + if rule.src_port: + self._default_config[i]["src_port"] = rule.src_port.name + if rule.dst_port: + self._default_config[i]["dst_port"] = rule.dst_port.name + if rule.protocol: + self._default_config[i]["protocol"] = rule.protocol.name + def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.implicit_rule.reset_component_for_episode(episode) super().reset_component_for_episode(episode) + self._reset_rules_to_default() + + def _reset_rules_to_default(self) -> None: + """Clear all ACL rules and set them to the default rules config.""" + self._acl = [None] * (self.max_acl_rules - 1) + for r_num, r_cfg in self._default_config.items(): + self.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -129,9 +162,9 @@ class AccessControlList(SimComponent): func=lambda request, context: self.add_rule( ACLAction[request[0]], None if request[1] == "ALL" else IPProtocol[request[1]], - IPv4Address(request[2]), + None if request[2] == "ALL" else IPv4Address(request[2]), None if request[3] == "ALL" else Port[request[3]], - IPv4Address(request[4]), + None if request[4] == "ALL" else IPv4Address(request[4]), None if request[5] == "ALL" else Port[request[5]], int(request[6]), ) @@ -333,11 +366,10 @@ class RouteEntry(SimComponent): """ Represents a single entry in a routing table. - Attributes: - address (IPv4Address): The destination IP address or network address. - subnet_mask (IPv4Address): The subnet mask for the network. - next_hop_ip_address (IPv4Address): The next hop IP address to which packets should be forwarded. - metric (int): The cost metric for this route. Default is 0.0. + :ivar address: The destination IP address or network address. + :ivar subnet_mask: The subnet mask for the network. + :ivar next_hop_ip_address: The next hop IP address to which packets should be forwarded. + :ivar metric: The cost metric for this route. Default is 0.0. Example: >>> entry = RouteEntry( @@ -357,12 +389,6 @@ class RouteEntry(SimComponent): metric: float = 0.0 "The cost metric for this route. Default is 0.0." - def __init__(self, **kwargs): - for key in {"address", "subnet_mask", "next_hop_ip_address"}: - if not isinstance(kwargs[key], IPv4Address): - kwargs[key] = IPv4Address(kwargs[key]) - super().__init__(**kwargs) - def set_original_state(self): """Sets the original state.""" vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} @@ -397,10 +423,10 @@ class RouteTable(SimComponent): """ routes: List[RouteEntry] = [] + default_route: Optional[RouteEntry] = None sys_log: SysLog def set_original_state(self): - """Sets the original state.""" """Sets the original state.""" super().set_original_state() self._original_state["routes_orig"] = self.routes @@ -442,12 +468,35 @@ class RouteTable(SimComponent): ) self.routes.append(route) + def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address): + """ + Sets the next-hop IP address for the default route in a routing table. + + This method checks if a default route (0.0.0.0/0) exists in the routing table. If it does not exist, + the method creates a new default route with the specified next-hop IP address. If a default route already + exists, it updates the next-hop IP address of the existing default route. After setting the next-hop + IP address, the method logs this action. + + :param ip_address: The next-hop IP address to be set for the default route. + """ + if not self.default_route: + self.default_route = RouteEntry( + ip_address=IPv4Address("0.0.0.0"), + subnet_mask=IPv4Address("0.0.0.0"), + next_hop_ip_address=ip_address, + ) + else: + self.default_route.next_hop_ip_address = ip_address + self.sys_log.info(f"Default configured to use {ip_address} as the next-hop") + def find_best_route(self, destination_ip: Union[str, IPv4Address]) -> Optional[RouteEntry]: """ Find the best route for a given destination IP. This method uses the Longest Prefix Match algorithm and considers metrics to find the best route. + If no dedicated route exists but a default route does, then the default route is returned as a last resort. + :param destination_ip: The destination IP to find the route for. :return: The best matching RouteEntry, or None if no route matches. """ @@ -467,6 +516,9 @@ class RouteTable(SimComponent): longest_prefix = prefix_len lowest_metric = route.metric + if not best_route and self.default_route: + best_route = self.default_route + return best_route def show(self, markdown: bool = False): @@ -498,12 +550,26 @@ class RouterARPCache(ARPCache): super().__init__(sys_log) self.router: Router = router - def process_arp_packet(self, from_nic: NIC, frame: Frame): + def process_arp_packet( + self, from_nic: NIC, frame: Frame, route_table: RouteTable, is_reattempt: bool = False + ) -> None: """ - Overridden method to process a received ARP packet in a router-specific way. + Processes a received ARP (Address Resolution Protocol) packet in a router-specific way. + + This method is responsible for handling both ARP requests and responses. It processes ARP packets received on a + Network Interface Card (NIC) and performs actions based on whether the packet is a request or a reply. This + includes updating the ARP cache, forwarding ARP replies, sending ARP requests for unknown destinations, and + handling packet TTL (Time To Live). + + The method first checks if the ARP packet is a request or a reply. For ARP replies, it updates the ARP cache + and forwards the reply if necessary. For ARP requests, it checks if the target IP matches one of the router's + NICs and sends an ARP reply if so. If the destination is not directly connected, it consults the routing table + to find the best route and reattempts ARP request processing if needed. :param from_nic: The NIC that received the ARP packet. - :param frame: The original ARP frame. + :param frame: The frame containing the ARP packet. + :param route_table: The routing table of the router. + :param is_reattempt: Flag to indicate if this is a reattempt of processing the ARP packet, defaults to False. """ arp_packet = frame.arp @@ -531,7 +597,11 @@ class RouterARPCache(ARPCache): ) arp_packet.sender_mac_addr = nic.mac_address frame.decrement_ttl() + if frame.ip and frame.ip.ttl < 1: + self.sys_log.info("Frame discarded as TTL limit reached") + return nic.send_frame(frame) + return # ARP Request self.sys_log.info( @@ -542,16 +612,32 @@ class RouterARPCache(ARPCache): self.add_arp_cache_entry( ip_address=arp_packet.sender_ip_address, mac_address=arp_packet.sender_mac_addr, nic=from_nic ) - arp_packet = arp_packet.generate_reply(from_nic.mac_address) - self.send_arp_reply(arp_packet, from_nic) # If the target IP matches one of the router's NICs for nic in self.nics.values(): - if nic.enabled and nic.ip_address == arp_packet.target_ip_address: + if arp_packet.target_ip_address in nic.ip_network: + # if nic.enabled and nic.ip_address == arp_packet.target_ip_address: arp_reply = arp_packet.generate_reply(from_nic.mac_address) self.send_arp_reply(arp_reply, from_nic) return + # Check Route Table + route = route_table.find_best_route(arp_packet.target_ip_address) + if route: + nic = self.get_arp_cache_nic(route.next_hop_ip_address) + + if not nic: + if not is_reattempt: + self.send_arp_request(route.next_hop_ip_address, ignore_networks=[frame.ip.src_ip_address]) + return self.process_arp_packet(from_nic, frame, route_table, is_reattempt=True) + else: + self.sys_log.info("Ignoring ARP request as destination unavailable/No ARP entry found") + return + else: + arp_reply = arp_packet.generate_reply(from_nic.mac_address) + self.send_arp_reply(arp_reply, from_nic) + return + class RouterICMP(ICMP): """ @@ -622,7 +708,7 @@ class RouterICMP(ICMP): return # Route the frame - self.router.route_frame(frame, from_nic) + self.router.process_frame(frame, from_nic) elif frame.icmp.icmp_type == ICMPType.ECHO_REPLY: for nic in self.router.nics.values(): @@ -642,7 +728,48 @@ class RouterICMP(ICMP): return # Route the frame - self.router.route_frame(frame, from_nic) + self.router.process_frame(frame, from_nic) + + +class RouterNIC(NIC): + """ + A Router-specific Network Interface Card (NIC) that extends the standard NIC functionality. + + This class overrides the standard Node NIC's Layer 3 (L3) broadcast/unicast checks. It is designed + to handle network frames in a manner specific to routers, allowing them to efficiently process + and route network traffic. + """ + + def receive_frame(self, frame: Frame) -> bool: + """ + Receive and process a network frame from the connected link, provided the NIC is enabled. + + This method is tailored for router behavior. It decrements the frame's Time To Live (TTL), checks for TTL + expiration, and captures the frame using PCAP (Packet Capture). The frame is accepted if it is destined for + this NIC's MAC address or is a broadcast frame. + + Key Differences from Standard NIC: + - Does not perform Layer 3 (IP-based) broadcast checks. + - Only checks for Layer 2 (Ethernet) destination MAC address and broadcast frames. + + :param frame: The network frame being received. This should be an instance of the Frame class. + :return: Returns True if the frame is processed and passed to the connected node, False otherwise. + """ + if self.enabled: + frame.decrement_ttl() + if frame.ip and frame.ip.ttl < 1: + self._connected_node.sys_log.info("Frame discarded as TTL limit reached") + return False + frame.set_received_timestamp() + self.pcap.capture(frame) + # If this destination or is broadcast + if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff": + self._connected_node.receive_frame(frame=frame, from_nic=self) + return True + return False + + def __str__(self) -> str: + return f"{self.mac_address}/{self.ip_address}" class Router(Node): @@ -655,7 +782,7 @@ class Router(Node): """ num_ports: int - ethernet_ports: Dict[int, NIC] = {} + ethernet_ports: Dict[int, RouterNIC] = {} acl: AccessControlList route_table: RouteTable arp: RouterARPCache @@ -674,7 +801,7 @@ class Router(Node): kwargs["icmp"] = RouterICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"), router=self) super().__init__(hostname=hostname, num_ports=num_ports, **kwargs) for i in range(1, self.num_ports + 1): - nic = NIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") + nic = RouterNIC(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0") self.connect_nic(nic) self.ethernet_ports[i] = nic @@ -725,13 +852,13 @@ class Router(Node): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = (self.num_ports,) - state["acl"] = (self.acl.describe_state(),) + state["num_ports"] = self.num_ports + state["acl"] = self.acl.describe_state() return state - def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: + def process_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: """ - Route a given frame from a source NIC to its destination. + Process a Frame. :param frame: The frame to be routed. :param from_nic: The source network interface. @@ -746,25 +873,57 @@ class Router(Node): return if not nic: - self.arp.send_arp_request(frame.ip.dst_ip_address) - return self.route_frame(frame=frame, from_nic=from_nic, re_attempt=True) + self.arp.send_arp_request( + frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address, from_nic.ip_address] + ) + return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True) if not nic.enabled: - # TODO: Add sys_log here + self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled") return if frame.ip.dst_ip_address in nic.ip_network: from_port = self._get_port_of_nic(from_nic) to_port = self._get_port_of_nic(nic) - self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}") + self.sys_log.info(f"Forwarding frame to internally from port {from_port} to port {to_port}") frame.decrement_ttl() + if frame.ip and frame.ip.ttl < 1: + self.sys_log.info("Frame discarded as TTL limit reached") + return frame.ethernet.src_mac_addr = nic.mac_address frame.ethernet.dst_mac_addr = target_mac nic.send_frame(frame) return else: - pass - # TODO: Deal with routing from route tables + self._route_frame(frame, from_nic) + + def _route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: + route = self.route_table.find_best_route(frame.ip.dst_ip_address) + if route: + nic = self.arp.get_arp_cache_nic(route.next_hop_ip_address) + target_mac = self.arp.get_arp_cache_mac_address(route.next_hop_ip_address) + if re_attempt and not nic: + self.sys_log.info(f"Destination {frame.ip.dst_ip_address} is unreachable") + return + + if not nic: + self.arp.send_arp_request(frame.ip.dst_ip_address, ignore_networks=[frame.ip.src_ip_address]) + return self.process_frame(frame=frame, from_nic=from_nic, re_attempt=True) + + if not nic.enabled: + self.sys_log.info(f"Frame dropped as NIC {nic} is not enabled") + return + + from_port = self._get_port_of_nic(from_nic) + to_port = self._get_port_of_nic(nic) + self.sys_log.info(f"Routing frame to internally from port {from_port} to port {to_port}") + frame.decrement_ttl() + if frame.ip and frame.ip.ttl < 1: + self.sys_log.info("Frame discarded as TTL limit reached") + return + frame.ethernet.src_mac_addr = nic.mac_address + frame.ethernet.dst_mac_addr = target_mac + nic.send_frame(frame) def receive_frame(self, frame: Frame, from_nic: NIC): """ @@ -773,7 +932,7 @@ class Router(Node): :param frame: The incoming frame. :param from_nic: The network interface where the frame is coming from. """ - route_frame = False + process_frame = False protocol = frame.ip.protocol src_ip_address = frame.ip.src_ip_address dst_ip_address = frame.ip.dst_ip_address @@ -805,12 +964,12 @@ class Router(Node): self.icmp.process_icmp(frame=frame, from_nic=from_nic) else: if src_port == Port.ARP: - self.arp.process_arp_packet(from_nic=from_nic, frame=frame) + self.arp.process_arp_packet(from_nic=from_nic, frame=frame, route_table=self.route_table) else: # All other traffic - route_frame = True - if route_frame: - self.route_frame(frame, from_nic) + process_frame = True + if process_frame: + self.process_frame(frame, from_nic) def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]): """ @@ -873,3 +1032,63 @@ class Router(Node): ] ) print(table) + + @classmethod + def from_config(cls, cfg: dict) -> "Router": + """Create a router based on a config dict. + + Schema: + - hostname (str): unique name for this router. + - num_ports (int, optional): Number of network ports on the router. 8 by default + - ports (dict): Dict with integers from 1 - num_ports as keys. The values should be another dict specifying + ip_address and subnet_mask assigned to that ports (as strings) + - acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL + where the rule will be added (lower number is resolved first). The values should describe valid ACL + Rules as: + - action (str): either PERMIT or DENY + - src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER + - protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP + - src_ip_address (str, optional): IP address octet written in base 10 + - dst_ip_address (str, optional): IP address octet written in base 10 + + Example config: + ``` + { + 'hostname': 'router_1', + 'num_ports': 5, + 'ports': { + 1: { + 'ip_address' : '192.168.1.1', + 'subnet_mask' : '255.255.255.0', + } + }, + 'acl' : { + 21: {'action': 'PERMIT', 'src_port': 'HTTP', dst_port: 'HTTP'}, + 22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'}, + 23: {'action': 'PERMIT', 'protocol': 'ICMP'}, + }, + } + ``` + + :param cfg: Router config adhering to schema described in main docstring body + :type cfg: dict + :return: Configured router. + :rtype: Router + """ + new = Router( + hostname=cfg["hostname"], + num_ports=cfg.get("num_ports"), + operating_state=NodeOperatingState.ON, + ) + if "ports" in cfg: + for port_num, port_cfg in cfg["ports"].items(): + new.configure_port( + port=port_num, + ip_address=port_cfg["ip_address"], + subnet_mask=port_cfg["subnet_mask"], + ) + if "acl" in cfg: + new.acl._default_config = cfg["acl"] # save the config to allow resetting + new.acl._reset_rules_to_default() # read the config and apply rules + return new diff --git a/src/primaite/simulator/network/hardware/nodes/switch.py b/src/primaite/simulator/network/hardware/nodes/switch.py index fffae6e2..b394bae0 100644 --- a/src/primaite/simulator/network/hardware/nodes/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/switch.py @@ -90,12 +90,12 @@ class Switch(Node): self._add_mac_table_entry(src_mac, incoming_port) outgoing_port = self.mac_address_table.get(dst_mac) - if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff": + if outgoing_port and dst_mac.lower() != "ff:ff:ff:ff:ff:ff": outgoing_port.send_frame(frame) else: # If the destination MAC is not in the table, flood to all ports except incoming for port in self.switch_ports.values(): - if port != incoming_port: + if port.enabled and port != incoming_port: port.send_frame(frame) def disconnect_link_from_port(self, link: Link, port_number: int): diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 898e5917..322ac808 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,9 +38,6 @@ class Application(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - self.health_state_visible = SoftwareHealthState.UNUSED - self.health_state_actual = SoftwareHealthState.UNUSED - def set_original_state(self): """Sets the original state.""" super().set_original_state() @@ -95,6 +92,9 @@ class Application(IOSoftware): if self.operating_state == ApplicationOperatingState.CLOSED: self.sys_log.info(f"Running Application {self.name}") self.operating_state = ApplicationOperatingState.RUNNING + # set software health state to GOOD if initially set to UNUSED + if self.health_state_actual == SoftwareHealthState.UNUSED: + self.set_health_state(SoftwareHealthState.GOOD) def _application_loop(self): """The main application loop.""" diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index a1429e51..a844f059 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -72,7 +72,7 @@ class DataManipulationBot(DatabaseClient): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) + rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack())) return rm @@ -83,7 +83,7 @@ class DataManipulationBot(DatabaseClient): payload: Optional[str] = None, port_scan_p_of_success: float = 0.1, data_manipulation_p_of_success: float = 0.1, - repeat: bool = False, + repeat: bool = True, ): """ Configure the DataManipulatorBot to communicate with a DatabaseService. @@ -168,6 +168,12 @@ class DataManipulationBot(DatabaseClient): Calls the parent classes execute method before starting the application loop. """ super().run() + + def attack(self): + """Perform the attack steps after opening the application.""" + if not self._can_perform_action(): + _LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.") + self.run() self._application_loop() def _application_loop(self): @@ -198,4 +204,4 @@ class DataManipulationBot(DatabaseClient): :param timestep: The timestep value to update the bot's state. """ - self._application_loop() + pass diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index 1539e024..bfb6a055 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -41,6 +41,9 @@ class PacketCapture: def setup_logger(self): """Set up the logger configuration.""" + if not SIM_OUTPUT.save_pcap_logs: + return + log_path = self._get_log_path() file_handler = logging.FileHandler(filename=log_path) @@ -88,5 +91,6 @@ class PacketCapture: :param frame: The PCAP frame to capture. """ - msg = frame.model_dump_json() - self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL + if SIM_OUTPUT.save_pcap_logs: + msg = frame.model_dump_json() + self.logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 8658f155..a95846a3 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ipaddress import IPv4Address +from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable @@ -141,41 +141,76 @@ class SessionManager: def receive_payload_from_software_manager( self, payload: Any, - dst_ip_address: Optional[IPv4Address] = None, + dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dst_port: Optional[Port] = None, session_id: Optional[str] = None, is_reattempt: bool = False, ) -> Union[Any, None]: """ - Receive a payload from the SoftwareManager. + Receive a payload from the SoftwareManager and send it to the appropriate NIC for transmission. - If no session_id, a Session is established. Once established, the payload is sent to ``send_payload_to_nic``. + This method supports both unicast and Layer 3 broadcast transmissions. If `dst_ip_address` is an + IPv4Network, a broadcast is initiated. For unicast, the destination MAC address is resolved via ARP. + A new session is established if `session_id` is not provided, and an existing session is used otherwise. :param payload: The payload to be sent. - :param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created. + :param dst_ip_address: The destination IP address or network for broadcast. Optional. + :param dst_port: The destination port for the TCP packet. Optional. + :param session_id: The Session ID from which the payload originates. Optional. + :param is_reattempt: Flag to indicate if this is a reattempt after an ARP request. Default is False. + :return: The outcome of sending the frame, or None if sending was unsuccessful. """ + is_broadcast = False + outbound_nic = None + dst_mac_address = None + + # Use session details if session_id is provided if session_id: session = self.sessions_by_uuid[session_id] - dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address - dst_port = self.sessions_by_uuid[session_id].dst_port + dst_ip_address = session.with_ip_address + dst_port = session.dst_port - dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address) + # Determine if the payload is for broadcast or unicast - if dst_mac_address: - outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address) + # Handle broadcast transmission + if isinstance(dst_ip_address, IPv4Network): + is_broadcast = True + dst_ip_address = dst_ip_address.broadcast_address + if dst_ip_address: + # Find a suitable NIC for the broadcast + for nic in self.arp_cache.nics.values(): + if dst_ip_address in nic.ip_network and nic.enabled: + dst_mac_address = "ff:ff:ff:ff:ff:ff" + outbound_nic = nic else: - if not is_reattempt: - self.arp_cache.send_arp_request(dst_ip_address) - return self.receive_payload_from_software_manager( - payload=payload, - dst_ip_address=dst_ip_address, - dst_port=dst_port, - session_id=session_id, - is_reattempt=True, - ) - else: - return + # Resolve MAC address for unicast transmission + dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address) + # Resolve outbound NIC for unicast transmission + if dst_mac_address: + outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address) + + # If MAC address not found, initiate ARP request + else: + if not is_reattempt: + self.arp_cache.send_arp_request(dst_ip_address) + # Reattempt payload transmission after ARP request + return self.receive_payload_from_software_manager( + payload=payload, + dst_ip_address=dst_ip_address, + dst_port=dst_port, + session_id=session_id, + is_reattempt=True, + ) + else: + # Return None if reattempt fails + return + + # Check if outbound NIC and destination MAC address are resolved + if not outbound_nic or not dst_mac_address: + return False + + # Construct the frame for transmission frame = Frame( ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address), ip=IPPacket( @@ -189,15 +224,17 @@ class SessionManager: payload=payload, ) - if not session_id: + # Manage session for unicast transmission + if not (is_broadcast and session_id): session_key = self._get_session_key(frame, inbound_frame=False) session = self.sessions_by_key.get(session_key) if not session: - # Create new session + # Create a new session if it doesn't exist session = Session.from_session_key(session_key) self.sessions_by_key[session_key] = session self.sessions_by_uuid[session.uuid] = session + # Send the frame through the NIC return outbound_nic.send_frame(frame) def receive_frame(self, frame: Frame): diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 21a121c1..95948a1e 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -1,4 +1,4 @@ -from ipaddress import IPv4Address +from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable @@ -130,20 +130,28 @@ class SoftwareManager: def send_payload_to_session_manager( self, payload: Any, - dest_ip_address: Optional[IPv4Address] = None, + dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dest_port: Optional[Port] = None, session_id: Optional[str] = None, ) -> bool: """ - Send a payload to the SessionManager. + Sends a payload to the SessionManager for network transmission. + + This method is responsible for initiating the process of sending network payloads. It supports both + unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified + as an IPv4Network. :param payload: The payload to be sent. - :param dest_ip_address: The ip address of the payload destination. - :param dest_port: The port of the payload destination. - :param session_id: The Session ID the payload is to originate from. Optional. + :param dest_ip_address: The IP address or network (for broadcasts) of the payload destination. + :param dest_port: The destination port for the payload. Optional. + :param session_id: The Session ID from which the payload originates. Optional. + :return: True if the payload was successfully sent, False otherwise. """ return self.session_manager.receive_payload_from_software_manager( - payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id + payload=payload, + dst_ip_address=dest_ip_address, + dst_port=dest_port, + session_id=session_id, ) def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 41ce8fee..00e6920b 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -41,6 +41,9 @@ class SysLog: The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out JSON-like messages. """ + if not SIM_OUTPUT.save_sys_logs: + return + log_path = self._get_log_path() file_handler = logging.FileHandler(filename=log_path) file_handler.setLevel(logging.DEBUG) @@ -91,7 +94,8 @@ class SysLog: :param msg: The message to be logged. """ - self.logger.debug(msg) + if SIM_OUTPUT.save_sys_logs: + self.logger.debug(msg) def info(self, msg: str): """ @@ -99,7 +103,8 @@ class SysLog: :param msg: The message to be logged. """ - self.logger.info(msg) + if SIM_OUTPUT.save_sys_logs: + self.logger.info(msg) def warning(self, msg: str): """ @@ -107,7 +112,8 @@ class SysLog: :param msg: The message to be logged. """ - self.logger.warning(msg) + if SIM_OUTPUT.save_sys_logs: + self.logger.warning(msg) def error(self, msg: str): """ @@ -115,7 +121,8 @@ class SysLog: :param msg: The message to be logged. """ - self.logger.error(msg) + if SIM_OUTPUT.save_sys_logs: + self.logger.error(msg) def critical(self, msg: str): """ @@ -123,4 +130,5 @@ class SysLog: :param msg: The message to be logged. """ - self.logger.critical(msg) + if SIM_OUTPUT.save_sys_logs: + self.logger.critical(msg) diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 6f333091..c9c4d6fa 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -3,6 +3,8 @@ from typing import Any, Dict, List, Literal, Optional, Union from primaite import getLogger from primaite.simulator.file_system.file_system import File +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus +from primaite.simulator.file_system.folder import Folder from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager @@ -22,7 +24,7 @@ class DatabaseService(Service): password: Optional[str] = None - backup_server: IPv4Address = None + backup_server_ip: IPv4Address = None """IP address of the backup server.""" latest_backup_directory: str = None @@ -36,7 +38,6 @@ class DatabaseService(Service): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self._db_file: File self._create_db_file() def set_original_state(self): @@ -45,8 +46,8 @@ class DatabaseService(Service): super().set_original_state() vals_to_include = { "password", - "_connections", - "backup_server", + "connections", + "backup_server_ip", "latest_backup_directory", "latest_backup_file_name", } @@ -64,7 +65,7 @@ class DatabaseService(Service): :param: backup_server_ip: The IP address of the backup server """ - self.backup_server = backup_server + self.backup_server_ip = backup_server def backup_database(self) -> bool: """Create a backup of the database to the configured backup server.""" @@ -73,7 +74,7 @@ class DatabaseService(Service): return False # check if the backup server was configured - if self.backup_server is None: + if self.backup_server_ip is None: self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.") return False @@ -81,10 +82,14 @@ class DatabaseService(Service): ftp_client_service: FTPClient = software_manager.software.get("FTPClient") # send backup copy of database file to FTP server + if not self.db_file: + self.sys_log.error("Attempted to backup database file but it doesn't exist.") + return False + response = ftp_client_service.send_file( - dest_ip_address=self.backup_server, - src_file_name=self._db_file.name, - src_folder_name=self.folder.name, + dest_ip_address=self.backup_server_ip, + src_file_name=self.db_file.name, + src_folder_name="database", dest_folder_name=str(self.uuid), dest_file_name="database.db", ) @@ -110,7 +115,7 @@ class DatabaseService(Service): src_file_name="database.db", dest_folder_name="downloads", dest_file_name="database.db", - dest_ip_address=self.backup_server, + dest_ip_address=self.backup_server_ip, ) if not response: @@ -118,13 +123,10 @@ class DatabaseService(Service): return False # replace db file - self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db") - self.file_system.copy_file( - src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name - ) - self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db") + self.file_system.delete_file(folder_name="database", file_name="database.db") + self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database") - if self._db_file is None: + if self.db_file is None: self.sys_log.error("Copying database backup failed.") return False @@ -134,12 +136,30 @@ class DatabaseService(Service): def _create_db_file(self): """Creates the Simulation File and sqlite file in the file system.""" - self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db") - self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id) + self.file_system.create_file(folder_name="database", file_name="database.db") + + @property + def db_file(self) -> File: + """Returns the database file.""" + return self.file_system.get_file(folder_name="database", file_name="database.db") + + @property + def folder(self) -> Folder: + """Returns the database folder.""" + return self.file_system.get_folder_by_id(self.db_file.folder_id) def _process_connect( self, connection_id: str, password: Optional[str] = None ) -> Dict[str, Union[int, Dict[str, bool]]]: + """Process an incoming connection request. + + :param connection_id: A unique identifier for the connection + :type connection_id: str + :param password: Supplied password. It must match self.password for connection success, defaults to None + :type password: Optional[str], optional + :return: Response to connection request containing success info. + :rtype: Dict[str, Union[int, Dict[str, bool]]] + """ status_code = 500 # Default internal server error if self.operating_state == ServiceOperatingState.RUNNING: status_code = 503 # service unavailable @@ -184,7 +204,7 @@ class DatabaseService(Service): self.sys_log.info(f"{self.name}: Running {query}") if query == "SELECT": - if self.health_state_actual == SoftwareHealthState.GOOD: + if self.db_file.health_status == FileSystemItemHealthStatus.GOOD: return { "status_code": 200, "type": "sql", @@ -195,17 +215,8 @@ class DatabaseService(Service): else: return {"status_code": 404, "data": False} elif query == "DELETE": - if self.health_state_actual == SoftwareHealthState.GOOD: - self.health_state_actual = SoftwareHealthState.COMPROMISED - return { - "status_code": 200, - "type": "sql", - "data": False, - "uuid": query_id, - "connection_id": connection_id, - } - else: - return {"status_code": 404, "data": False} + self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED + return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id, "connection_id": connection_id} else: # Invalid query return {"status_code": 500, "data": False} @@ -265,3 +276,19 @@ class DatabaseService(Service): software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) return payload["status_code"] == 200 + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a single timestep of simulation dynamics to this service. + + Here at the first step, the database backup is created, in addition to normal service update logic. + """ + if timestep == 1: + self.backup_database() + return super().apply_timestep(timestep) + + def _update_patch_status(self) -> None: + """Perform a database restore when the patching countdown is finished.""" + super()._update_patch_status() + if self._patching_countdown is None: + self.restore_backup() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 7faa5d32..39bc57f0 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -89,6 +89,7 @@ class FTPClient(FTPServiceABC): f"{self.name}: Successfully connected to FTP Server " f"{dest_ip_address} via port {payload.ftp_command_args.value}" ) + self.add_connection(connection_id="server_connection", session_id=session_id) return True else: if is_reattempt: diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 585690b6..a82b0919 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -99,5 +99,5 @@ class FTPServer(FTPServiceABC): if payload.status_code is not None: return False - self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id) + self._process_ftp_command(payload=payload, session_id=session_id) return True diff --git a/src/primaite/simulator/system/services/ftp/ftp_service.py b/src/primaite/simulator/system/services/ftp/ftp_service.py index f2c01544..70ba74d7 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_service.py +++ b/src/primaite/simulator/system/services/ftp/ftp_service.py @@ -1,7 +1,7 @@ import shutil from abc import ABC from ipaddress import IPv4Address -from typing import Optional +from typing import Dict, Optional from primaite.simulator.file_system.file_system import File from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode @@ -16,6 +16,10 @@ class FTPServiceABC(Service, ABC): Contains shared methods between both classes. """ + def describe_state(self) -> Dict: + """Returns a Dict of the FTPService state.""" + return super().describe_state() + def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. @@ -52,10 +56,12 @@ class FTPServiceABC(Service, ABC): folder_name = payload.ftp_command_args["dest_folder_name"] file_size = payload.ftp_command_args["file_size"] real_file_path = payload.ftp_command_args.get("real_file_path") + health_status = payload.ftp_command_args["health_status"] is_real = real_file_path is not None file = self.file_system.create_file( file_name=file_name, folder_name=folder_name, size=file_size, real=is_real ) + file.health_status = health_status self.sys_log.info( f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/" f"{payload.ftp_command_args['dest_file_name']}" @@ -110,6 +116,7 @@ class FTPServiceABC(Service, ABC): "dest_file_name": dest_file_name, "file_size": file.sim_size, "real_file_path": file.sim_path if file.real else None, + "health_status": file.health_status, }, packet_payload_size=file.sim_size, status_code=FTPStatusCode.OK if is_response else None, diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index d45ef3a6..162678a0 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from enum import Enum from typing import Any, Dict, Optional @@ -43,9 +44,6 @@ class Service(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - self.health_state_visible = SoftwareHealthState.UNUSED - self.health_state_actual = SoftwareHealthState.UNUSED - def _can_perform_action(self) -> bool: """ Checks if the service can perform actions. @@ -98,6 +96,7 @@ class Service(IOSoftware): rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) return rm + @abstractmethod def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -118,7 +117,6 @@ class Service(IOSoftware): if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Stopping service {self.name}") self.operating_state = ServiceOperatingState.STOPPED - self.health_state_actual = SoftwareHealthState.UNUSED def start(self, **kwargs) -> None: """Start the service.""" @@ -129,42 +127,39 @@ class Service(IOSoftware): if self.operating_state == ServiceOperatingState.STOPPED: self.sys_log.info(f"Starting service {self.name}") self.operating_state = ServiceOperatingState.RUNNING - self.health_state_actual = SoftwareHealthState.GOOD + # set software health state to GOOD if initially set to UNUSED + if self.health_state_actual == SoftwareHealthState.UNUSED: + self.set_health_state(SoftwareHealthState.GOOD) def pause(self) -> None: """Pause the service.""" if self.operating_state == ServiceOperatingState.RUNNING: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.PAUSED - self.health_state_actual = SoftwareHealthState.OVERWHELMED def resume(self) -> None: """Resume paused service.""" if self.operating_state == ServiceOperatingState.PAUSED: self.sys_log.info(f"Resuming service {self.name}") self.operating_state = ServiceOperatingState.RUNNING - self.health_state_actual = SoftwareHealthState.GOOD def restart(self) -> None: """Restart running service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.RESTARTING - self.health_state_actual = SoftwareHealthState.OVERWHELMED self.restart_countdown = self.restart_duration def disable(self) -> None: """Disable the service.""" self.sys_log.info(f"Disabling Application {self.name}") self.operating_state = ServiceOperatingState.DISABLED - self.health_state_actual = SoftwareHealthState.OVERWHELMED def enable(self) -> None: """Enable the disabled service.""" if self.operating_state == ServiceOperatingState.DISABLED: self.sys_log.info(f"Enabling Application {self.name}") self.operating_state = ServiceOperatingState.STOPPED - self.health_state_actual = SoftwareHealthState.OVERWHELMED def apply_timestep(self, timestep: int) -> None: """ @@ -181,5 +176,4 @@ class Service(IOSoftware): if self.restart_countdown <= 0: _LOGGER.debug(f"Restarting finished for service {self.name}") self.operating_state = ServiceOperatingState.RUNNING - self.health_state_actual = SoftwareHealthState.GOOD self.restart_countdown -= 1 diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index afd6cb74..eaea6bb1 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -13,6 +13,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import SoftwareHealthState _LOGGER = getLogger(__name__) @@ -123,7 +124,10 @@ class WebServer(Service): # get all users if db_client.query("SELECT"): # query succeeded + self.set_health_state(SoftwareHealthState.GOOD) response.status_code = HttpStatusCode.OK + else: + self.set_health_state(SoftwareHealthState.COMPROMISED) return response except Exception: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 27d5b3b3..662db08e 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -2,8 +2,8 @@ import copy from abc import abstractmethod from datetime import datetime from enum import Enum -from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from ipaddress import IPv4Address, IPv4Network +from typing import Any, Dict, Optional, Union from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -38,12 +38,12 @@ class SoftwareHealthState(Enum): "Unused state." GOOD = 1 "The software is in a good and healthy condition." - COMPROMISED = 2 - "The software's security has been compromised." - OVERWHELMED = 3 - "he software is overwhelmed and not functioning properly." - PATCHING = 4 + PATCHING = 2 "The software is undergoing patching or updates." + COMPROMISED = 3 + "The software's security has been compromised." + OVERWHELMED = 4 + "he software is overwhelmed and not functioning properly." class SoftwareCriticality(Enum): @@ -71,9 +71,9 @@ class Software(SimComponent): name: str "The name of the software." - health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD + health_state_actual: SoftwareHealthState = SoftwareHealthState.UNUSED "The actual health state of the software." - health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD + health_state_visible: SoftwareHealthState = SoftwareHealthState.UNUSED "The health state of the software visible to the red agent." criticality: SoftwareCriticality = SoftwareCriticality.LOWEST "The criticality level of the software." @@ -195,8 +195,9 @@ class Software(SimComponent): def patch(self) -> None: """Perform a patch on the software.""" - self._patching_countdown = self.patching_duration - self.set_health_state(SoftwareHealthState.PATCHING) + if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD): + self._patching_countdown = self.patching_duration + self.set_health_state(SoftwareHealthState.PATCHING) def _update_patch_status(self) -> None: """Update the patch status of the software.""" @@ -282,7 +283,7 @@ class IOSoftware(Software): Returns true if the software can perform actions. """ - if self.software_manager and self.software_manager.node.operating_state is NodeOperatingState.OFF: + if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON: _LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.") return False return True @@ -303,13 +304,13 @@ class IOSoftware(Software): """ # if over or at capacity, set to overwhelmed if len(self._connections) >= self.max_sessions: - self.health_state_actual = SoftwareHealthState.OVERWHELMED + self.set_health_state(SoftwareHealthState.OVERWHELMED) self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") return False else: # if service was previously overwhelmed, set to good because there is enough space for connections if self.health_state_actual == SoftwareHealthState.OVERWHELMED: - self.health_state_actual = SoftwareHealthState.GOOD + self.set_health_state(SoftwareHealthState.GOOD) # check that connection already doesn't exist if not self._connections.get(connection_id): @@ -350,19 +351,22 @@ class IOSoftware(Software): self, payload: Any, session_id: Optional[str] = None, - dest_ip_address: Optional[IPv4Address] = None, + dest_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, dest_port: Optional[Port] = None, **kwargs, ) -> bool: """ - Sends a payload to the SessionManager. + Sends a payload to the SessionManager for network transmission. + + This method is responsible for initiating the process of sending network payloads. It supports both + unicast and Layer 3 broadcast transmissions. For broadcasts, the destination IP should be specified + as an IPv4Network. It delegates the actual sending process to the SoftwareManager. :param payload: The payload to be sent. - :param dest_ip_address: The ip address of the payload destination. - :param dest_port: The port of the payload destination. - :param session_id: The Session ID the payload is to originate from. Optional. - - :return: True if successful, False otherwise. + :param dest_ip_address: The IP address or network (for broadcasts) of the payload destination. + :param dest_port: The destination port for the payload. Optional. + :param session_id: The Session ID from which the payload originates. Optional. + :return: True if the payload was successfully sent, False otherwise. """ if not self._can_perform_action(): return False diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 4c1d7ce7..3e9be3bb 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -19,7 +19,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -489,6 +489,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 29b7937b..0c3872b0 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -23,7 +23,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -493,6 +493,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 54727790..87bcc14f 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -29,7 +29,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -500,6 +500,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: @@ -929,6 +946,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index f677b4e0..84b1c15f 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -27,7 +27,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -500,6 +500,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index b788e33f..62826cd4 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -23,7 +23,7 @@ game: - UDP agents: - - ref: client_1_green_user + - ref: client_2_green_user team: GREEN type: GreenWebBrowsingAgent observation_space: @@ -501,6 +501,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/conftest.py b/tests/conftest.py index 1ab07dd8..c37226a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,6 +40,9 @@ from primaite.simulator.network.hardware.base import Link, Node class TestService(Service): """Test Service class""" + def describe_state(self) -> Dict: + return super().describe_state() + def __init__(self, **kwargs): kwargs["name"] = "TestService" kwargs["port"] = Port.HTTP @@ -60,7 +63,7 @@ class TestApplication(Application): super().__init__(**kwargs) def describe_state(self) -> Dict: - pass + return super().describe_state() @pytest.fixture(scope="function") @@ -167,7 +170,7 @@ def example_network() -> Network: -------------- -------------- | client_1 |----- ----| server_1 | -------------- | -------------- -------------- -------------- | -------------- - ------| switch_1 |------| router_1 |------| switch_2 |------ + ------| switch_2 |------| router_1 |------| switch_1 |------ -------------- | -------------- -------------- -------------- | -------------- | client_2 |---- ----| server_2 | -------------- -------------- diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 5206561b..992ed533 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -22,7 +22,7 @@ def test_data_manipulation(uc2_network): assert db_client.query("SELECT") # Now we run the DataManipulationBot - db_manipulation_bot.run() + db_manipulation_bot.attack() # Now check that the DB client on the web_server cannot query the users table on the database assert not db_client.query("SELECT") diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py new file mode 100644 index 00000000..b9ecb28b --- /dev/null +++ b/tests/integration_tests/network/test_broadcast.py @@ -0,0 +1,180 @@ +from ipaddress import IPv4Address, IPv4Network +from typing import Any, Dict, List, Tuple + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.hardware.nodes.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.services.service import Service + + +class BroadcastService(Service): + """A service for sending broadcast and unicast messages over a network.""" + + def __init__(self, **kwargs): + # Set default service properties for broadcasting + kwargs["name"] = "BroadcastService" + kwargs["port"] = Port.HTTP + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + # Implement state description for the service + pass + + def unicast(self, ip_address: IPv4Address): + # Send a unicast payload to a specific IP address + super().send( + payload="unicast", + dest_ip_address=ip_address, + dest_port=Port.HTTP, + ) + + def broadcast(self, ip_network: IPv4Network): + # Send a broadcast payload to an entire IP network + super().send( + payload="broadcast", + dest_ip_address=ip_network, + dest_port=Port.HTTP, + ) + + +class BroadcastClient(Application): + """A client application to receive broadcast and unicast messages.""" + + payloads_received: List = [] + + def __init__(self, **kwargs): + # Set default client properties + kwargs["name"] = "BroadcastClient" + kwargs["port"] = Port.HTTP + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + # Implement state description for the application + pass + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + # Append received payloads to the list and print a message + self.payloads_received.append(payload) + print(f"Payload: {payload} received on node {self.sys_log.hostname}") + + +@pytest.fixture(scope="function") +def broadcast_network() -> Network: + network = Network() + + client_1 = Computer( + hostname="client_1", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + client_1.power_on() + client_1.software_manager.install(BroadcastClient) + application_1 = client_1.software_manager.software["BroadcastClient"] + application_1.run() + + client_2 = Computer( + hostname="client_2", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + client_2.power_on() + client_2.software_manager.install(BroadcastClient) + application_2 = client_2.software_manager.software["BroadcastClient"] + application_2.run() + + server_1 = Server( + hostname="server_1", + ip_address="192.168.1.1", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server_1.power_on() + + server_1.software_manager.install(BroadcastService) + service: BroadcastService = server_1.software_manager.software["BroadcastService"] + service.start() + + switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1.power_on() + + network.connect(endpoint_a=client_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[1]) + network.connect(endpoint_a=client_2.ethernet_port[1], endpoint_b=switch_1.switch_ports[2]) + network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_1.switch_ports[3]) + + return network + + +@pytest.fixture(scope="function") +def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]: + client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[ + "BroadcastClient" + ] + client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[ + "BroadcastClient" + ] + service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[ + "BroadcastService" + ] + + return service, client_1, client_2 + + +def test_broadcast_correct_subnet(broadcast_service_and_clients): + service, client_1, client_2 = broadcast_service_and_clients + + assert not client_1.payloads_received + assert not client_2.payloads_received + + service.broadcast(IPv4Network("192.168.1.0/24")) + + assert client_1.payloads_received == ["broadcast"] + assert client_2.payloads_received == ["broadcast"] + + +def test_broadcast_incorrect_subnet(broadcast_service_and_clients): + service, client_1, client_2 = broadcast_service_and_clients + + assert not client_1.payloads_received + assert not client_2.payloads_received + + service.broadcast(IPv4Network("192.168.2.0/24")) + + assert not client_1.payloads_received + assert not client_2.payloads_received + + +def test_unicast_correct_address(broadcast_service_and_clients): + service, client_1, client_2 = broadcast_service_and_clients + + assert not client_1.payloads_received + assert not client_2.payloads_received + + service.unicast(IPv4Address("192.168.1.2")) + + assert client_1.payloads_received == ["unicast"] + assert not client_2.payloads_received + + +def test_unicast_incorrect_address(broadcast_service_and_clients): + service, client_1, client_2 = broadcast_service_and_clients + + assert not client_1.payloads_received + assert not client_2.payloads_received + + service.unicast(IPv4Address("192.168.2.2")) + + assert not client_1.payloads_received + assert not client_2.payloads_received diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 6053c457..042debca 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -1,11 +1,16 @@ +from ipaddress import IPv4Address from typing import Tuple import pytest +from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer @pytest.fixture(scope="function") @@ -34,6 +39,69 @@ def pc_a_pc_b_router_1() -> Tuple[Node, Node, Router]: return pc_a, pc_b, router_1 +@pytest.fixture(scope="function") +def multi_hop_network() -> Network: + network = Network() + + # Configure PC A + pc_a = Computer( + hostname="pc_a", + ip_address="192.168.0.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.0.1", + start_up_duration=0, + ) + pc_a.power_on() + network.add_node(pc_a) + + # Configure Router 1 + router_1 = Router(hostname="router_1", start_up_duration=0) + router_1.power_on() + network.add_node(router_1) + + # Configure the connection between PC A and Router 1 port 2 + router_1.configure_port(2, "192.168.0.1", "255.255.255.0") + network.connect(pc_a.ethernet_port[1], router_1.ethernet_ports[2]) + router_1.enable_port(2) + + # Configure Router 1 ACLs + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + + # Configure PC B + pc_b = Computer( + hostname="pc_b", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.2.1", + start_up_duration=0, + ) + pc_b.power_on() + network.add_node(pc_b) + + # Configure Router 2 + router_2 = Router(hostname="router_2", start_up_duration=0) + router_2.power_on() + network.add_node(router_2) + + # Configure the connection between PC B and Router 2 port 2 + router_2.configure_port(2, "192.168.2.1", "255.255.255.0") + network.connect(pc_b.ethernet_port[1], router_2.ethernet_ports[2]) + router_2.enable_port(2) + + # Configure Router 2 ACLs + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_2.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + + # Configure the connection between Router 1 port 1 and Router 2 port 1 + router_2.configure_port(1, "192.168.1.2", "255.255.255.252") + router_1.configure_port(1, "192.168.1.1", "255.255.255.252") + network.connect(router_1.ethernet_ports[1], router_2.ethernet_ports[1]) + router_1.enable_port(1) + router_2.enable_port(1) + return network + + def test_ping_default_gateway(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 @@ -50,3 +118,68 @@ def test_host_on_other_subnet(pc_a_pc_b_router_1): pc_a, pc_b, router_1 = pc_a_pc_b_router_1 assert pc_a.ping("192.168.1.10") + + +def test_no_route_no_ping(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + pc_b = multi_hop_network.get_node_by_hostname("pc_b") + + assert not pc_a.ping(pc_b.ethernet_port[1].ip_address) + + +def test_with_routes_can_ping(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + pc_b = multi_hop_network.get_node_by_hostname("pc_b") + + router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa + router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa + + # Configure Route from Router 1 to PC B subnet + router_1.route_table.add_route( + address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2" + ) + + # Configure Route from Router 2 to PC A subnet + router_2.route_table.add_route( + address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1" + ) + + assert pc_a.ping(pc_b.ethernet_port[1].ip_address) + + +def test_routing_services(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + + pc_b = multi_hop_network.get_node_by_hostname("pc_b") + + pc_a.software_manager.install(NTPClient) + ntp_client = pc_a.software_manager.software["NTPClient"] + ntp_client.start() + + pc_b.software_manager.install(NTPServer) + pc_b.software_manager.software["NTPServer"].start() + + ntp_client.configure(ntp_server_ip_address=pc_b.ethernet_port[1].ip_address) + + router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa + router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa + + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.NTP, dst_port=Port.NTP, position=21) + + assert ntp_client.time is None + ntp_client.request_time() + assert ntp_client.time is None + + # Configure Route from Router 1 to PC B subnet + router_1.route_table.add_route( + address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2" + ) + + # Configure Route from Router 2 to PC A subnet + router_2.route_table.add_route( + address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1" + ) + + ntp_client.request_time() + assert ntp_client.time is not None diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py index 85028d75..fb768127 100644 --- a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -90,7 +90,7 @@ def test_repeating_dos_attack(dos_bot_and_db_server): assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED db_server_service.clear_connections() - db_server_service.health_state_actual = SoftwareHealthState.GOOD + db_server_service.set_health_state(SoftwareHealthState.GOOD) assert len(db_server_service.connections) == 0 computer.apply_timestep(timestep=1) @@ -121,7 +121,7 @@ def test_non_repeating_dos_attack(dos_bot_and_db_server): assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED db_server_service.clear_connections() - db_server_service.health_state_actual = SoftwareHealthState.GOOD + db_server_service.set_health_state(SoftwareHealthState.GOOD) assert len(db_server_service.connections) == 0 computer.apply_timestep(timestep=1) diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index 46be5e55..60497f22 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -24,8 +24,8 @@ def populated_node(application_class) -> Tuple[Application, Computer]: return app, computer -def test_service_on_offline_node(application_class): - """Test to check that the service cannot be interacted with when node it is on is off.""" +def test_application_on_offline_node(application_class): + """Test to check that the application cannot be interacted with when node it is on is off.""" computer: Computer = Computer( hostname="test_computer", ip_address="192.168.1.2", @@ -49,8 +49,8 @@ def test_service_on_offline_node(application_class): assert app.operating_state is ApplicationOperatingState.CLOSED -def test_server_turns_off_service(populated_node): - """Check that the service is turned off when the server is turned off""" +def test_server_turns_off_application(populated_node): + """Check that the application is turned off when the server is turned off""" app, computer = populated_node assert computer.operating_state is NodeOperatingState.ON @@ -65,8 +65,8 @@ def test_server_turns_off_service(populated_node): assert app.operating_state is ApplicationOperatingState.CLOSED -def test_service_cannot_be_turned_on_when_server_is_off(populated_node): - """Check that the service cannot be started when the server is off.""" +def test_application_cannot_be_turned_on_when_computer_is_off(populated_node): + """Check that the application cannot be started when the computer is off.""" app, computer = populated_node assert computer.operating_state is NodeOperatingState.ON @@ -86,8 +86,8 @@ def test_service_cannot_be_turned_on_when_server_is_off(populated_node): assert app.operating_state is ApplicationOperatingState.CLOSED -def test_server_turns_on_service(populated_node): - """Check that turning on the server turns on service.""" +def test_computer_runs_applications(populated_node): + """Check that turning on the computer will turn on applications.""" app, computer = populated_node assert computer.operating_state is NodeOperatingState.ON @@ -109,13 +109,14 @@ def test_server_turns_on_service(populated_node): assert computer.operating_state is NodeOperatingState.ON assert app.operating_state is ApplicationOperatingState.RUNNING - computer.start_up_duration = 0 - computer.shut_down_duration = 0 - computer.power_off() + for i in range(computer.start_up_duration + 1): + computer.apply_timestep(timestep=i) assert computer.operating_state is NodeOperatingState.OFF assert app.operating_state is ApplicationOperatingState.CLOSED computer.power_on() + for i in range(computer.start_up_duration + 1): + computer.apply_timestep(timestep=i) assert computer.operating_state is NodeOperatingState.ON assert app.operating_state is ApplicationOperatingState.RUNNING diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py index aab1e4da..9b0084bd 100644 --- a/tests/integration_tests/system/test_service_on_node.py +++ b/tests/integration_tests/system/test_service_on_node.py @@ -117,13 +117,14 @@ def test_server_turns_on_service(populated_node): assert server.operating_state is NodeOperatingState.ON assert service.operating_state is ServiceOperatingState.RUNNING - server.start_up_duration = 0 - server.shut_down_duration = 0 - server.power_off() + for i in range(server.start_up_duration + 1): + server.apply_timestep(timestep=i) assert server.operating_state is NodeOperatingState.OFF assert service.operating_state is ServiceOperatingState.STOPPED server.power_on() + for i in range(server.start_up_duration + 1): + server.apply_timestep(timestep=i) assert server.operating_state is NodeOperatingState.ON assert service.operating_state is ServiceOperatingState.RUNNING diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 5fe5df16..b6f7a86d 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -53,12 +53,12 @@ def test_node_os_scan(node, service, application): # TODO implement processes # add services to node - service.health_state_actual = SoftwareHealthState.COMPROMISED + service.set_health_state(SoftwareHealthState.COMPROMISED) node.install_service(service=service) assert service.health_state_visible == SoftwareHealthState.UNUSED # add application to node - application.health_state_actual = SoftwareHealthState.COMPROMISED + application.set_health_state(SoftwareHealthState.COMPROMISED) node.install_application(application=application) assert application.health_state_visible == SoftwareHealthState.UNUSED @@ -101,7 +101,7 @@ def test_node_red_scan(node, service, application): assert service.revealed_to_red is False # add application to node - application.health_state_actual = SoftwareHealthState.COMPROMISED + application.set_health_state(SoftwareHealthState.COMPROMISED) node.install_application(application=application) assert application.revealed_to_red is False diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index e348838e..7667a59f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien from primaite.simulator.system.services.database.database_service import DatabaseService +def filter_keys_nested_item(data, keys): + stack = [(data, {})] + while stack: + current, filtered = stack.pop() + if isinstance(current, dict): + for k, v in current.items(): + if k in keys: + filtered[k] = filter_keys_nested_item(v, keys) + elif isinstance(v, (dict, list)): + stack.append((v, {})) + elif isinstance(current, list): + for item in current: + stack.append((item, {})) + return filtered + + @pytest.fixture(scope="function") def network(example_network) -> Network: assert len(example_network.routers) is 1 @@ -59,10 +75,10 @@ def test_reset_network(network): assert client_1.operating_state is NodeOperatingState.ON assert server_1.operating_state is NodeOperatingState.ON - - assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps( - state_before, sort_keys=True, indent=2 - ) + # don't worry if UUIDs change + a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) + b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) + assert a == b def test_creating_container(): diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_actions.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py new file mode 100644 index 00000000..6247a100 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_applications.py @@ -0,0 +1,50 @@ +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.software import SoftwareHealthState + + +def test_scan(application): + assert application.operating_state == ApplicationOperatingState.CLOSED + assert application.health_state_visible == SoftwareHealthState.UNUSED + + application.run() + assert application.operating_state == ApplicationOperatingState.RUNNING + assert application.health_state_visible == SoftwareHealthState.UNUSED + + application.scan() + assert application.operating_state == ApplicationOperatingState.RUNNING + assert application.health_state_visible == SoftwareHealthState.GOOD + + +def test_run_application(application): + assert application.operating_state == ApplicationOperatingState.CLOSED + assert application.health_state_actual == SoftwareHealthState.UNUSED + + application.run() + assert application.operating_state == ApplicationOperatingState.RUNNING + assert application.health_state_actual == SoftwareHealthState.GOOD + + +def test_close_application(application): + application.run() + assert application.operating_state == ApplicationOperatingState.RUNNING + assert application.health_state_actual == SoftwareHealthState.GOOD + + application.close() + assert application.operating_state == ApplicationOperatingState.CLOSED + assert application.health_state_actual == SoftwareHealthState.GOOD + + +def test_application_describe_states(application): + assert application.operating_state == ApplicationOperatingState.CLOSED + assert application.health_state_actual == SoftwareHealthState.UNUSED + + assert SoftwareHealthState.UNUSED.value == application.describe_state().get("health_state_actual") + + application.run() + assert SoftwareHealthState.GOOD.value == application.describe_state().get("health_state_actual") + + application.set_health_state(SoftwareHealthState.COMPROMISED) + assert SoftwareHealthState.COMPROMISED.value == application.describe_state().get("health_state_actual") + + application.patch() + assert SoftwareHealthState.PATCHING.value == application.describe_state().get("health_state_actual") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 134f82bd..941a465e 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -2,6 +2,7 @@ from ipaddress import IPv4Address import pytest +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer @@ -42,6 +43,7 @@ def test_ftp_client_store_file(ftp_client): "dest_folder_name": "downloads", "dest_file_name": "file.txt", "file_size": 24, + "health_status": FileSystemItemHealthStatus.GOOD, }, packet_payload_size=24, status_code=FTPStatusCode.OK, diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index 2b26c932..137e74d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -1,5 +1,6 @@ import pytest +from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.server import Server @@ -41,6 +42,7 @@ def test_ftp_server_store_file(ftp_server): "dest_folder_name": "downloads", "dest_file_name": "file.txt", "file_size": 24, + "health_status": FileSystemItemHealthStatus.GOOD, }, packet_payload_size=24, ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py index 016cf011..ac36c660 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py @@ -19,55 +19,146 @@ def test_scan(service): def test_start_service(service): assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_actual == SoftwareHealthState.UNUSED service.start() assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_actual == SoftwareHealthState.GOOD def test_stop_service(service): service.start() assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_actual == SoftwareHealthState.GOOD service.stop() assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_actual == SoftwareHealthState.GOOD def test_pause_and_resume_service(service): assert service.operating_state == ServiceOperatingState.STOPPED service.resume() assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_actual == SoftwareHealthState.UNUSED service.start() + assert service.health_state_actual == SoftwareHealthState.GOOD service.pause() assert service.operating_state == ServiceOperatingState.PAUSED + assert service.health_state_actual == SoftwareHealthState.GOOD service.resume() assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_actual == SoftwareHealthState.GOOD def test_restart(service): assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_actual == SoftwareHealthState.UNUSED service.restart() + # Service is STOPPED. Restart will only work if the service was PAUSED or RUNNING assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_actual == SoftwareHealthState.UNUSED service.start() + assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_actual == SoftwareHealthState.GOOD service.restart() + # Service is RUNNING. Restart should work assert service.operating_state == ServiceOperatingState.RESTARTING + assert service.health_state_actual == SoftwareHealthState.GOOD timestep = 0 while service.operating_state == ServiceOperatingState.RESTARTING: service.apply_timestep(timestep) + assert service.health_state_actual == SoftwareHealthState.GOOD timestep += 1 assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_actual == SoftwareHealthState.GOOD + + +def test_restart_compromised(service): + service.start() + assert service.health_state_actual == SoftwareHealthState.GOOD + + # compromise the service + service.set_health_state(SoftwareHealthState.COMPROMISED) + + service.restart() + assert service.operating_state == ServiceOperatingState.RESTARTING + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + """ + Service should be compromised even after reset. + + Only way to remove compromised status is via patching. + """ + + timestep = 0 + while service.operating_state == ServiceOperatingState.RESTARTING: + service.apply_timestep(timestep) + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + timestep += 1 + + assert service.operating_state == ServiceOperatingState.RUNNING + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + +def test_compromised_service_remains_compromised(service): + """ + Tests that a compromised service stays compromised. + + The only way that the service can be uncompromised is by running patch. + """ + service.start() + assert service.health_state_actual == SoftwareHealthState.GOOD + + service.set_health_state(SoftwareHealthState.COMPROMISED) + + service.stop() + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + service.start() + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + service.disable() + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + service.enable() + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + service.pause() + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + service.resume() + assert service.health_state_actual == SoftwareHealthState.COMPROMISED + + +def test_service_patching(service): + service.start() + assert service.health_state_actual == SoftwareHealthState.GOOD + + service.set_health_state(SoftwareHealthState.COMPROMISED) + + service.patch() + assert service.health_state_actual == SoftwareHealthState.PATCHING + + for i in range(service.patching_duration + 1): + service.apply_timestep(i) + + assert service.health_state_actual == SoftwareHealthState.GOOD def test_enable_disable(service): service.disable() assert service.operating_state == ServiceOperatingState.DISABLED + assert service.health_state_actual == SoftwareHealthState.UNUSED service.enable() assert service.operating_state == ServiceOperatingState.STOPPED + assert service.health_state_actual == SoftwareHealthState.UNUSED def test_overwhelm_service(service): @@ -76,13 +167,13 @@ def test_overwhelm_service(service): uuid = str(uuid4()) assert service.add_connection(connection_id=uuid) # should be true - assert service.health_state_actual is SoftwareHealthState.GOOD + assert service.health_state_actual == SoftwareHealthState.GOOD assert not service.add_connection(connection_id=uuid) # fails because connection already exists - assert service.health_state_actual is SoftwareHealthState.GOOD + assert service.health_state_actual == SoftwareHealthState.GOOD assert service.add_connection(connection_id=str(uuid4())) # succeed - assert service.health_state_actual is SoftwareHealthState.GOOD + assert service.health_state_actual == SoftwareHealthState.GOOD assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity assert service.health_state_actual is SoftwareHealthState.OVERWHELMED diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index bbccda27..64277356 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -1,5 +1,6 @@ import pytest +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.protocols.http import ( HttpRequestMethod, @@ -15,7 +16,11 @@ from primaite.simulator.system.services.web_server.web_server import WebServer @pytest.fixture(scope="function") def web_server() -> Server: node = Server( - hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" + hostname="web_server", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, ) node.software_manager.install(software_class=WebServer) node.software_manager.software.get("WebServer").start() diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py new file mode 100644 index 00000000..e77cd895 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -0,0 +1,29 @@ +from typing import Dict + +import pytest + +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.sys_log import SysLog +from primaite.simulator.system.software import Software, SoftwareHealthState + + +class TestSoftware(Software): + def describe_state(self) -> Dict: + pass + + +@pytest.fixture(scope="function") +def software(file_system): + return TestSoftware( + name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + ) + + +def test_software_creation(software): + assert software is not None + + +def test_software_set_health_state(software): + assert software.health_state_actual == SoftwareHealthState.UNUSED + software.set_health_state(SoftwareHealthState.GOOD) + assert software.health_state_actual == SoftwareHealthState.GOOD