From 64b9ba3ecf2e1865e902917bd80d05bac70ab0bd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:21:03 +0000 Subject: [PATCH 01/18] Make environment reset reinstantiate the game --- src/primaite/game/game.py | 5 +--- .../notebooks/training_example_sb3.ipynb | 4 +-- src/primaite/session/environment.py | 27 ++++++++++++------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1fd0dc8b..091438ce 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -67,9 +67,6 @@ class PrimaiteGame: self.step_counter: int = 0 """Current timestep within the episode.""" - self.episode_counter: int = 0 - """Current episode number.""" - self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" @@ -163,7 +160,7 @@ class PrimaiteGame: return True return False - def reset(self) -> None: + def reset(self) -> None: # TODO: deprecated - remove me """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e5085c5e..164142b2 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -38,7 +38,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game)" + "gym = PrimaiteGymEnv(game_config=cfg)" ] }, { @@ -65,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=1000)\n" + "model.learn(total_timesteps=10)\n" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a3831bc1..ad770f8f 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -18,11 +18,18 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame): + def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() - self.game: "PrimaiteGame" = game + self.game_config: Dict = game_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + """Current game.""" self.agent: ProxyAgent = self.game.rl_agents[0] + """The agent within the game that is controlled by the RL algorithm.""" + + self.episode_counter: int = 0 + """Current episode number.""" def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -45,13 +52,13 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), @@ -63,10 +70,12 @@ class PrimaiteGymEnv(gymnasium.Env): def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" print( - f"Resetting environment, episode {self.game.episode_counter}, " + f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" ) - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.agent = self.game.rl_agents[0] + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -107,7 +116,7 @@ class PrimaiteRayEnv(gymnasium.Env): :type env_config: Dict[str, PrimaiteGame] """ self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) - self.env.game.episode_counter -= 1 + self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -194,13 +203,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards, From f82506023bcd978ff28e28a25491eb6b6facdee7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:29:27 +0000 Subject: [PATCH 02/18] Delete set_original_state method definitions --- src/primaite/game/game.py | 2 - src/primaite/simulator/core.py | 10 +--- src/primaite/simulator/domain/account.py | 13 ----- src/primaite/simulator/file_system/file.py | 9 ---- .../simulator/file_system/file_system.py | 11 ---- .../file_system/file_system_item_abc.py | 5 -- src/primaite/simulator/file_system/folder.py | 17 ------- src/primaite/simulator/network/container.py | 7 --- .../simulator/network/hardware/base.py | 50 ------------------- .../network/hardware/nodes/router.py | 48 ------------------ src/primaite/simulator/sim_container.py | 4 -- .../system/applications/application.py | 6 --- .../system/applications/database_client.py | 8 --- .../red_applications/data_manipulation_bot.py | 15 ------ .../applications/red_applications/dos_bot.py | 16 ------ .../system/applications/web_browser.py | 8 --- .../simulator/system/processes/process.py | 6 --- .../services/database/database_service.py | 13 ----- .../system/services/dns/dns_client.py | 7 --- .../system/services/dns/dns_server.py | 7 --- .../system/services/ftp/ftp_client.py | 7 --- .../system/services/ftp/ftp_server.py | 7 --- .../simulator/system/services/service.py | 6 --- .../system/services/web_server/web_server.py | 7 --- src/primaite/simulator/system/software.py | 19 ------- .../_simulator/_domain/test_account.py | 2 - .../_file_system/test_file_system.py | 1 - .../_simulator/_network/test_container.py | 1 - .../_red_applications/test_dos_bot.py | 2 - 29 files changed, 1 insertion(+), 313 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 091438ce..bd7ed2cd 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -386,6 +386,4 @@ class PrimaiteGame: else: _LOGGER.warning(f"agent type {agent_type} not found") - game.simulation.set_original_state() - return game diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 98a7e8db..e21ce9eb 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,8 +153,6 @@ class SimComponent(BaseModel): uuid: str """The component UUID.""" - _original_state: Dict = {} - def __init__(self, **kwargs): if not kwargs.get("uuid"): kwargs["uuid"] = str(uuid4()) @@ -162,15 +160,9 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - # @abstractmethod - def set_original_state(self): - """Sets the original state.""" - pass - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - for key, value in self._original_state.items(): - self.__setattr__(key, value) + pass def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d9dad06a..186caf5b 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,19 +42,6 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "num_logons", - "num_logoffs", - "num_group_changes", - "username", - "password", - "account_type", - "enabled", - } - self._original_state = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 608a1d78..4cd5cdbb 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,15 +73,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}") - super().set_original_state() - vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ee80587d..a7252a2d 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -34,17 +34,6 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}") - for folder in self.folders.values(): - folder.set_original_state() - # Capture a list of all 'original' file uuids - original_keys = list(self.folders.keys()) - vals_to_include = {"sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_folder_uuids"] = original_keys - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index c3e1426b..fbe5f4b3 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} - self._original_state = self.model_dump(include=vals_to_keep) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 13fdc597..39c3dad8 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -49,23 +49,6 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}") - for file in self.files.values(): - file.set_original_state() - super().set_original_state() - vals_to_include = { - "scan_duration", - "scan_countdown", - "red_scan_duration", - "red_scan_countdown", - "restore_duration", - "restore_countdown", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_file_uuids"] = list(self.files.keys()) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 8989a60f..48205bbd 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,13 +45,6 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def set_original_state(self): - """Sets the original state.""" - for node in self.nodes.values(): - node.set_original_state() - for link in self.links.values(): - link.set_original_state() - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 01dd736d..68f3816d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,13 +123,6 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" super().reset_component_for_episode(episode) @@ -349,14 +342,6 @@ class SwitchPort(SimComponent): kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -506,14 +491,6 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"bandwidth", "current_load"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -1033,33 +1010,6 @@ class Node(SimComponent): self.arp.nics = self.nics self.session_manager.software_manager = self.software_manager self._install_system_software() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - for software in self.software_manager.software.values(): - software.set_original_state() - - self.file_system.set_original_state() - - for nic in self.nics.values(): - nic.set_original_state() - - vals_to_include = { - "hostname", - "default_gateway", - "operating_state", - "revealed_to_red", - "start_up_duration", - "start_up_countdown", - "shut_down_duration", - "shut_down_countdown", - "is_resetting", - "node_scan_duration", - "node_scan_countdown", - "red_scan_countdown", - } - self._original_state = self.model_dump(include=vals_to_include) def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 9a34be0b..4b379be0 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -53,11 +53,6 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -101,28 +96,6 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.implicit_rule.set_original_state() - vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - - for i, rule in enumerate(self._acl): - if not rule: - continue - self._default_config[i] = {"action": rule.action.name} - if rule.src_ip_address: - self._default_config[i]["src_ip"] = str(rule.src_ip_address) - if rule.dst_ip_address: - self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) - if rule.src_port: - self._default_config[i]["src_port"] = rule.src_port.name - if rule.dst_port: - self._default_config[i]["dst_port"] = rule.dst_port.name - if rule.protocol: - self._default_config[i]["protocol"] = rule.protocol.name def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" @@ -389,11 +362,6 @@ class RouteEntry(SimComponent): metric: float = 0.0 "The cost metric for this route. Default is 0.0." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} - self._original_values = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -426,11 +394,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - self._original_state["routes_orig"] = self.routes - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.routes.clear() @@ -808,16 +771,6 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.acl.set_original_state() - self.route_table.set_original_state() - super().set_original_state() - vals_to_include = {"num_ports"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() @@ -987,7 +940,6 @@ class Router(Node): nic.ip_address = ip_address nic.subnet_mask = subnet_mask self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}") - self.set_original_state() def enable_port(self, port: int): """ diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 896861e6..18ed894c 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,10 +21,6 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - self.network.set_original_state() - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.network.reset_component_for_episode(episode) diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 322ac808..513606a9 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,12 +38,6 @@ class Application(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 2e0f4e3f..d05472d4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -30,14 +30,6 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"} - self._original_state.update(self.model_dump(include=vals_to_include)) def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index a844f059..bd4048c4 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -49,21 +49,6 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "server_ip_address", - "payload", - "server_password", - "port_scan_p_of_success", - "data_manipulation_p_of_success", - "attack_stage", - "repeat", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index dfc48dd3..d4ea1a20 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -57,22 +57,6 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def set_original_state(self): - """Set the original state of the Denial of Service Bot.""" - _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "target_ip_address", - "target_port", - "payload", - "repeat", - "attack_stage", - "max_sessions", - "port_scan_p_of_success", - "dos_intensity", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index eef0ed5d..f1dbe3ef 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -47,16 +47,8 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.set_original_state() self.run() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index b753e3ad..458a6b5c 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,12 +24,6 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index d75b4424..4159c87c 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -40,19 +40,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "password", - "connections", - "backup_server_ip", - "latest_backup_directory", - "latest_backup_file_name", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 2d3879ff..3c034705 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,13 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_server"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.dns_cache.clear() diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8decf7e9..eab94766 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,13 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_table"} - self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.dns_table.clear() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 39bc57f0..457eaea9 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,13 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"connected"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index a82b0919..9534a5e9 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,13 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_password"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 162678a0..4102657c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -78,12 +78,6 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index eaea6bb1..5888e72a 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,13 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"last_response_status_code"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 662db08e..1fb8c989 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -96,19 +96,6 @@ class Software(SimComponent): _patching_countdown: Optional[int] = None "Current number of ticks left to patch the software." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "name", - "health_state_actual", - "health_state_visible", - "criticality", - "patching_count", - "scanning_count", - "revealed_to_red", - } - self._original_state = self.model_dump(include=vals_to_include) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -245,12 +232,6 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 01ad3871..695b15dd 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -7,7 +7,6 @@ from primaite.simulator.domain.account import Account, AccountType @pytest.fixture(scope="function") def account() -> Account: acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER) - acct.set_original_state() return acct @@ -39,7 +38,6 @@ def test_original_state(account): account.log_on() account.log_off() account.disable() - account.set_original_state() account.log_on() state = account.describe_state() diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 9366d173..2fe3f04c 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -189,7 +189,6 @@ def test_reset_file_system(file_system): # file and folder that existed originally file_system.create_file(file_name="test_file.zip") file_system.create_folder(folder_name="test_folder") - file_system.set_original_state() # create a new file file_system.create_file(file_name="new_file.txt") diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 7667a59f..994e5a45 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -33,7 +33,6 @@ def network(example_network) -> Network: assert len(example_network.computers) is 2 assert len(example_network.servers) is 2 - example_network.set_original_state() example_network.show() return example_network diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 71489171..da29a439 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -22,7 +22,6 @@ def dos_bot() -> DoSBot: dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) - dos_bot.set_original_state() return dos_bot @@ -51,7 +50,6 @@ def test_dos_bot_reset(dos_bot): dos_bot.configure( target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True ) - dos_bot.set_original_state() dos_bot.reset_component_for_episode(episode=1) # should reset to the configured value assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") From 72f4cc0a5073e79f7b9734b27739b3264513c6f8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:56:25 +0000 Subject: [PATCH 03/18] Remove reset methods from most classes --- src/primaite/simulator/file_system/file.py | 5 --- .../simulator/file_system/file_system.py | 26 ----------- src/primaite/simulator/file_system/folder.py | 26 ----------- .../simulator/network/hardware/base.py | 9 ---- .../network/hardware/nodes/router.py | 45 ++++++------------- .../system/applications/database_client.py | 6 --- .../red_applications/data_manipulation_bot.py | 5 --- .../applications/red_applications/dos_bot.py | 5 --- .../system/applications/web_browser.py | 8 ---- .../services/database/database_service.py | 6 --- .../system/services/dns/dns_client.py | 5 --- .../system/services/dns/dns_server.py | 7 --- .../system/services/ftp/ftp_client.py | 5 --- .../system/services/ftp/ftp_server.py | 6 --- .../system/services/ntp/ntp_client.py | 13 +----- .../system/services/ntp/ntp_server.py | 10 ----- .../system/services/web_server/web_server.py | 5 --- 17 files changed, 16 insertions(+), 176 deletions(-) diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 4cd5cdbb..d9b02e8e 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,11 +73,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") - super().reset_component_for_episode(episode) - @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index a7252a2d..8fd4e5d7 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -34,32 +34,6 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") - # Move any 'original' folder that have been deleted back to folders - original_folder_uuids = self._original_state["original_folder_uuids"] - for uuid in original_folder_uuids: - if uuid in self.deleted_folders: - folder = self.deleted_folders[uuid] - self.deleted_folders.pop(uuid) - self.folders[uuid] = folder - - # Clear any other deleted folders that aren't original (have been created by agent) - self.deleted_folders.clear() - - # Now clear all non-original folders created by agent - current_folder_uuids = list(self.folders.keys()) - for uuid in current_folder_uuids: - if uuid not in original_folder_uuids: - folder = self.folders[uuid] - self.folders.pop(uuid) - - # Now reset all remaining folders - for folder in self.folders.values(): - folder.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 39c3dad8..771dc7a0 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -49,32 +49,6 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") - # Move any 'original' file that have been deleted back to files - original_file_uuids = self._original_state["original_file_uuids"] - for uuid in original_file_uuids: - if uuid in self.deleted_files: - file = self.deleted_files[uuid] - self.deleted_files.pop(uuid) - self.files[uuid] = file - - # Clear any other deleted files that aren't original (have been created by agent) - self.deleted_files.clear() - - # Now clear all non-original files created by agent - current_file_uuids = list(self.files.keys()) - for uuid in current_file_uuids: - if uuid not in original_file_uuids: - file = self.files[uuid] - self.files.pop(uuid) - - # Now reset all remaining files - for file in self.files.values(): - file.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 68f3816d..67ac42c8 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1015,15 +1015,6 @@ class Node(SimComponent): """Reset the original state of the SimComponent.""" super().reset_component_for_episode(episode) - # Reset ARP Cache - self.arp.clear() - - # Reset ICMP - self.icmp.clear() - - # Reset Session Manager - self.session_manager.clear() - # Reset File System self.file_system.reset_component_for_episode(episode) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 4b379be0..aa154ad9 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -84,9 +84,7 @@ class AccessControlList(SimComponent): implicit_action: ACLAction implicit_rule: ACLRule max_acl_rules: int = 25 - _acl: List[Optional[ACLRule]] = [None] * 24 - _default_config: Dict[int, dict] = {} - """Config dict describing how the ACL list should look at episode start""" + _acl: List[Optional[ACLRule]] = [None] * 24 # TODO: this ignores the max_acl_rules and assumes it's default def __init__(self, **kwargs) -> None: if not kwargs.get("implicit_action"): @@ -97,26 +95,6 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.implicit_rule.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - self._reset_rules_to_default() - - def _reset_rules_to_default(self) -> None: - """Clear all ACL rules and set them to the default rules config.""" - self._acl = [None] * (self.max_acl_rules - 1) - for r_num, r_cfg in self._default_config.items(): - self.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("src_ip"), - dst_ip_address=r_cfg.get("dst_ip"), - position=r_num, - ) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -394,12 +372,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.routes.clear() - self.routes = self._original_state["routes_orig"] - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -1040,7 +1012,18 @@ class Router(Node): ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"], ) + + # Add the router's default ACL rules from the config. if "acl" in cfg: - new.acl._default_config = cfg["acl"] # save the config to allow resetting - new.acl._reset_rules_to_default() # read the config and apply rules + for r_num, r_cfg in cfg["acl"].items(): + new.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + return new diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index d05472d4..25730c38 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -31,12 +31,6 @@ class DatabaseClient(Application): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - self._query_success_tracker.clear() - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index bd4048c4..5fe951b7 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -49,11 +49,6 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index d4ea1a20..9dac6b25 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -57,11 +57,6 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index f1dbe3ef..6f2c479c 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -49,11 +49,6 @@ class WebBrowser(Application): super().__init__(**kwargs) self.run() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -72,9 +67,6 @@ class WebBrowser(Application): state["history"] = [hist_item.state() for hist_item in self.history] return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - def get_webpage(self, url: Optional[str] = None) -> bool: """ Retrieve the webpage. diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 4159c87c..5425ce75 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -40,12 +40,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 3c034705..967af6b2 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,11 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_cache.clear() - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index eab94766..4d0ebbb8 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,13 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_table.clear() - for key, value in self._original_state["dns_table_orig"].items(): - self.dns_table[key] = value - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 457eaea9..7c334ced 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,11 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 9534a5e9..c5330de2 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,12 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index e8c3d0cb..5e4ae53a 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -1,6 +1,6 @@ from datetime import datetime from ipaddress import IPv4Address -from typing import Dict, Optional +from typing import Dict, List, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket @@ -49,21 +49,12 @@ class NTPClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def send( self, payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: [Port] = Port.NTP, + dest_port: List[Port] = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 0a66384a..29a320f6 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -34,16 +34,6 @@ class NTPServer(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including - resetting any stateful properties or statistics, and clearing any message - queues. - """ - pass - def receive( self, payload: NTPPacket, diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5888e72a..5e4a6f6e 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,11 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. From 771a68dccba056428caee4fec18a9c0a4e4c2648 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 22 Feb 2024 22:43:14 +0000 Subject: [PATCH 04/18] #2238 - Implement NMNE detection and logging in NetworkInterface. - Enhance NicObservation for detailed NMNE event monitoring. - Add nmne_config options to simulation settings for customizable NMNE capturing. - Update documentation and tests for new NMNE features and simulation config. --- CHANGELOG.md | 6 +- .../network/network_interfaces.rst | 11 +- .../config/_package_data/example_config.yaml | 4 + .../example_config_2_rl_agents.yaml | 4 + src/primaite/game/agent/observations.py | 52 +++++++- src/primaite/game/game.py | 4 + .../simulator/network/hardware/base.py | 102 +++++++++++++-- .../network/hardware/nodes/host/host_node.py | 1 + src/primaite/simulator/network/nmne.py | 46 +++++++ .../network/test_capture_nmne.py | 120 ++++++++++++++++++ 10 files changed, 333 insertions(+), 17 deletions(-) create mode 100644 src/primaite/simulator/network/nmne.py create mode 100644 tests/integration_tests/network/test_capture_nmne.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 01e45d2e..40ac6535 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,7 +82,8 @@ SessionManager. - `AirSpace` class to simulate wireless communications, managing wireless interfaces and facilitating the transmission of frames within specified frequencies. - `AirSpaceFrequency` enum for defining standard wireless frequencies, including 2.4 GHz and 5 GHz bands, to support realistic wireless network simulations. - `WirelessRouter` class, extending the `Router` class, to incorporate wireless networking capabilities alongside traditional wired connections. This class allows the configuration of wireless access points with specific IP settings and operating frequencies. - +- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events. +- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE". ### Changed - Integrated the RouteTable into the Routers frame processing. @@ -94,7 +95,8 @@ SessionManager. - Refactored all tests to utilise new `Node` subclasses (`Computer`, `Server`, `Router`, `Switch`) instead of creating generic `Node` instances and manually adding network interfaces. This change aligns test setups more closely with the intended use cases and hierarchies within the network simulation framework. - Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios. - **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules. - +- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them. +- Integration of NMNE capturing functionality within the `NicObservation` class. ### Removed - Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol` diff --git a/docs/source/simulation_components/network/network_interfaces.rst b/docs/source/simulation_components/network/network_interfaces.rst index 9e1ad80a..c74b54ae 100644 --- a/docs/source/simulation_components/network/network_interfaces.rst +++ b/docs/source/simulation_components/network/network_interfaces.rst @@ -65,9 +65,14 @@ Network Interface Classes **NetworkInterface (Base Layer)** -Abstract base class defining core interface properties like MAC address, speed, MTU. -Requires subclasses implement key methods like send/receive frames, enable/disable interface. -Establishes universal network interface capabilities. +- Abstract base class defining core interface properties like MAC address, speed, MTU. +- Requires subclasses implement key methods like send/receive frames, enable/disable interface. +- Establishes universal network interface capabilities. +- Malicious Network Events Monitoring: + + * Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns. + * Integrates NMNE detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies. + * Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms. **WiredNetworkInterface (Connection Type Layer)** diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f85baf10..a72ebeca 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -583,6 +583,10 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - ref: router_1 diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 93019c9d..12461547 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -963,6 +963,10 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - ref: router_1 diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index dfee2543..1d8799fd 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -8,6 +8,7 @@ from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import CAPTURE_NMNE _LOGGER = getLogger(__name__) @@ -346,7 +347,14 @@ class FolderObservation(AbstractObservation): class NicObservation(AbstractObservation): """Observation of a Network Interface Card (NIC) in the network.""" - default_observation: spaces.Space = {"nic_status": 0} + @property + def default_observation(self) -> Dict: + """The default NIC observation dict.""" + data = {"nic_status": 0} + + if CAPTURE_NMNE: + data.update({"nmne": {"inbound": 0, "outbound": 0}}) + return data def __init__(self, where: Optional[Tuple[str]] = None) -> None: """Initialise NIC observation. @@ -360,6 +368,29 @@ class NicObservation(AbstractObservation): super().__init__() self.where: Optional[Tuple[str]] = where + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (1-5 events). + - 2: Moderate number of MNEs (6-10 events). + - 3: High number of MNEs (more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. + """ + if nmne_count > 10: + return 3 + elif nmne_count > 5: + return 2 + elif nmne_count > 0: + return 1 + return 0 + def observe(self, state: Dict) -> Dict: """Generate observation based on the current state of the simulation. @@ -371,15 +402,30 @@ class NicObservation(AbstractObservation): if self.where is None: return self.default_observation nic_state = access_from_nested_dict(state, self.where) + if nic_state is NOT_PRESENT_IN_STATE: return self.default_observation else: - return {"nic_status": 1 if nic_state["enabled"] else 2} + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2, "nmne": {}} + if CAPTURE_NMNE: + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) + obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) + return obs_dict @property def space(self) -> spaces.Space: """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"nic_status": spaces.Discrete(3)}) + return spaces.Dict( + { + "nic_status": spaces.Discrete(3), + "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), + } + ) @classmethod def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ed98accd..1f5dc8fa 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,6 +17,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -426,4 +427,7 @@ class PrimaiteGame: game.simulation.set_original_state() + # Set the NMNE capture config + set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {})) + return game diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index fa135674..c0e69e60 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -17,6 +17,15 @@ from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.nmne import ( + CAPTURE_BY_DIRECTION, + CAPTURE_BY_IP_ADDRESS, + CAPTURE_BY_KEYWORD, + CAPTURE_BY_PORT, + CAPTURE_BY_PROTOCOL, + CAPTURE_NMNE, + NMNE_CAPTURE_KEYWORDS, +) from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture @@ -88,6 +97,8 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." + nmne: Dict = Field(default_factory=lambda: {}) + def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -111,27 +122,99 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) + state.update({"nmne": self.nmne}) return state def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" super().reset_component_for_episode(episode) + self.nmne = {} if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() self.enable() - @abstractmethod + # @abstractmethod def enable(self): """Enable the interface.""" pass - @abstractmethod + # @abstractmethod def disable(self): """Disable the interface.""" pass - @abstractmethod + def _capture_nmne(self, frame: Frame, inbound: bool = True): + """ + Processes and captures network frame data based on predefined global NMNE settings. + + This method updates the NMNE structure with counts of malicious network events based on the frame content and + direction. The structure is dynamically adjusted according to the enabled capture settings. + + :param frame: The network frame to process, containing IP, TCP/UDP, and payload information. + :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. + """ + # Exit function if NMNE capturing is disabled + if not CAPTURE_NMNE: + return + + # Initialise basic frame data variables + direction = "inbound" if inbound else "outbound" # Direction of the traffic + ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP + protocol = frame.ip.protocol.name # Network protocol used in the frame + + # Initialise port variable; will be determined based on protocol type + port = None + + # Determine the source or destination port based on the protocol (TCP/UDP) + if frame.tcp: + port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + elif frame.udp: + port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + + # Convert frame payload to string for keyword checking + frame_str = str(frame.payload) + + # Proceed only if any NMNE keyword is present in the frame payload + if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS): + # Start with the root of the NMNE capture structure + current_level = self.nmne + + # Update NMNE structure based on enabled settings + if CAPTURE_BY_DIRECTION: + # Set or get the dictionary for the current direction + current_level = current_level.setdefault("direction", {}) + current_level = current_level.setdefault(direction, {}) + + if CAPTURE_BY_IP_ADDRESS: + # Set or get the dictionary for the current IP address + current_level = current_level.setdefault("ip_address", {}) + current_level = current_level.setdefault(ip_address, {}) + + if CAPTURE_BY_PROTOCOL: + # Set or get the dictionary for the current protocol + current_level = current_level.setdefault("protocol", {}) + current_level = current_level.setdefault(protocol, {}) + + if CAPTURE_BY_PORT: + # Set or get the dictionary for the current port + current_level = current_level.setdefault("port", {}) + current_level = current_level.setdefault(port, {}) + + # Ensure 'KEYWORD' level is present in the structure + keyword_level = current_level.setdefault("keywords", {}) + + # Increment the count for detected keywords in the payload + if CAPTURE_BY_KEYWORD: + for keyword in NMNE_CAPTURE_KEYWORDS: + if keyword in frame_str: + # Update the count for each keyword found + keyword_level[keyword] = keyword_level.get(keyword, 0) + 1 + else: + # Increment a generic counter if keyword capturing is not enabled + keyword_level["*"] = keyword_level.get("*", 0) + 1 + + # @abstractmethod def send_frame(self, frame: Frame) -> bool: """ Attempts to send a network frame through the interface. @@ -139,9 +222,9 @@ class NetworkInterface(SimComponent, ABC): :param frame: The network frame to be sent. :return: A boolean indicating whether the frame was successfully sent. """ - pass + self._capture_nmne(frame, inbound=False) - @abstractmethod + # @abstractmethod def receive_frame(self, frame: Frame) -> bool: """ Receives a network frame on the interface. @@ -149,7 +232,7 @@ class NetworkInterface(SimComponent, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + self._capture_nmne(frame, inbound=True) def __str__(self) -> str: """ @@ -263,6 +346,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame to be sent. :return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link. """ + super().send_frame(frame) if self.enabled: frame.set_sent_timestamp() self.pcap.capture_outbound(frame) @@ -279,7 +363,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + return super().receive_frame(frame) class Layer3Interface(BaseModel, ABC): @@ -409,7 +493,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): except AttributeError: pass - # @abstractmethod + @abstractmethod def receive_frame(self, frame: Frame) -> bool: """ Receives a network frame on the network interface. @@ -417,7 +501,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + return super().receive_frame(frame) class Link(SimComponent): diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 3f34f736..6ecd6733 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -248,6 +248,7 @@ class NIC(IPWiredNetworkInterface): accept_frame = True if accept_frame: + super().receive_frame(frame) self._connected_node.receive_frame(frame=frame, from_network_interface=self) return True return False diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py new file mode 100644 index 00000000..d4c40631 --- /dev/null +++ b/src/primaite/simulator/network/nmne.py @@ -0,0 +1,46 @@ +from typing import Dict, Final, List + +CAPTURE_NMNE: bool = True +"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True.""" + +NMNE_CAPTURE_KEYWORDS: List[str] = [] +"""List of keywords to identify malicious network events.""" + +CAPTURE_BY_DIRECTION: Final[bool] = True +"""Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" +CAPTURE_BY_IP_ADDRESS: Final[bool] = False +"""Flag to determine if captures should be organized by source or destination IP address.""" +CAPTURE_BY_PROTOCOL: Final[bool] = False +"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP).""" +CAPTURE_BY_PORT: Final[bool] = False +"""Flag to determine if captures should be organized by source or destination port.""" +CAPTURE_BY_KEYWORD: Final[bool] = False +"""Flag to determine if captures should be filtered and categorised based on specific keywords.""" + + +def set_nmne_config(nmne_config: Dict): + """ + Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary. + + This function updates global settings related to NMNE capture, including whether to capture NMNEs and what + keywords to use for identifying NMNEs. + + The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary, + and maintains type integrity by checking the types of the provided values. + + :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include: + "capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings) + to specify keywords for NMNE identification. + """ + global NMNE_CAPTURE_KEYWORDS + global CAPTURE_NMNE + + # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect + CAPTURE_NMNE = nmne_config.get("capture_nmne", False) + if not isinstance(CAPTURE_NMNE, bool): + CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean + + # Update the NMNE capture keywords, appending new keywords if provided + NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", []) + if not isinstance(NMNE_CAPTURE_KEYWORDS, list): + NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py new file mode 100644 index 00000000..85ac23e8 --- /dev/null +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -0,0 +1,120 @@ +from primaite.game.agent.observations import NicObservation +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient + + +def test_capture_nmne(uc2_network): + """ + Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured. + + This test involves a web server querying a database server and checks if the MNEs are captured + based on predefined keywords in the network configuration. Specifically, it checks the capture + of the "DELETE" SQL command as a malicious network event. + """ + web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa + db_client.connect() + + db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa + + web_server_nic = web_server.network_interface[1] + db_server_nic = db_server.network_interface[1] + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Assert that initially, there are no captured MNEs on both web and database servers + assert web_server_nic.describe_state()["nmne"] == {} + assert db_server_nic.describe_state()["nmne"] == {} + + # Perform a "SELECT" query + db_client.query("SELECT") + + # Check that it does not trigger an MNE capture. + assert web_server_nic.describe_state()["nmne"] == {} + assert db_server_nic.describe_state()["nmne"] == {} + + # Perform a "DELETE" query + db_client.query("DELETE") + + # Check that the web server's outbound interface and the database server's inbound interface register the MNE + assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "SELECT" query + db_client.query("SELECT") + + # Check that no additional MNEs are captured + assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "DELETE" query + db_client.query("DELETE") + + # Check that the web server and database server interfaces register an additional MNE + assert web_server_nic.describe_state()["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}} + assert db_server_nic.describe_state()["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}} + + +def test_capture_nmne_observations(uc2_network): + """ + Tests the NicObservation class's functionality within a simulated network environment. + + This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the + number of MNEs detected based on network activities over multiple iterations. + + The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update + and accuracy of the observation space related to network interface conditions. It confirms that the + observed NIC states match expected MNE activity levels. + """ + # Initialise a new Simulation instance and assign the test network to it. + sim = Simulation() + sim.network = uc2_network + + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client.connect() + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Define observations for the NICs of the database and web servers + db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1]) + web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1]) + + # Iterate through a set of test cases to simulate multiple DELETE queries + for i in range(1, 20): + # Perform a "DELETE" query each iteration + db_client.query("DELETE") + + # Observe the current state of NMNEs from the NICs of both the database and web servers + db_nic_obs = db_server_nic_obs.observe(sim.describe_state())["nmne"] + web_nic_obs = web_server_nic_obs.observe(sim.describe_state())["nmne"] + + # Define expected NMNE values based on the iteration count + if i > 10: + expected_nmne = 3 # High level of detected MNEs after 10 iterations + elif i > 5: + expected_nmne = 2 # Moderate level after more than 5 iterations + elif i > 0: + expected_nmne = 1 # Low level detected after just starting + else: + expected_nmne = 0 # No MNEs detected + + # Assert that the observed NMNEs match the expected values for both NICs + assert web_nic_obs["outbound"] == expected_nmne + assert db_nic_obs["inbound"] == expected_nmne From f933341df521feaca5e494bf739833b86d75ab28 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 23 Feb 2024 10:06:48 +0000 Subject: [PATCH 05/18] eod commit --- src/primaite/game/game.py | 10 ---------- src/primaite/simulator/core.py | 8 ++++++-- src/primaite/simulator/network/container.py | 8 ++++---- src/primaite/simulator/network/hardware/base.py | 14 +++++++------- .../simulator/network/hardware/nodes/router.py | 10 +++++----- src/primaite/simulator/sim_container.py | 4 ++-- 6 files changed, 24 insertions(+), 30 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index bd7ed2cd..72ad01e7 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -160,16 +160,6 @@ class PrimaiteGame: return True return False - def reset(self) -> None: # TODO: deprecated - remove me - """Reset the game, this will reset the simulation.""" - self.episode_counter += 1 - self.step_counter = 0 - _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(episode=self.episode_counter) - for agent in self.agents: - agent.reward_function.total_reward = 0.0 - agent.reset_agent_for_episode() - def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index e21ce9eb..b9188bf0 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -160,8 +160,12 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" + def setup_for_episode(self, episode: int): + """ + Perform any additional setup on this component that can't happen during __init__. + + For instance, some components may require for the entire network to exist before some configuration can be set. + """ pass def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 48205bbd..080a1004 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,12 +45,12 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.reset_component_for_episode(episode) + node.setup_for_episode(episode) for link in self.links.values(): - link.reset_component_for_episode(episode) + link.setup_for_episode(episode) for node in self.nodes.values(): node.power_on() @@ -171,7 +171,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.reset_component_for_episode() + link.setup_for_episode() def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 67ac42c8..e2a90db1 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,9 +123,9 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() @@ -1011,19 +1011,19 @@ class Node(SimComponent): self.session_manager.software_manager = self.software_manager self._install_system_software() - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) # Reset File System - self.file_system.reset_component_for_episode(episode) + self.file_system.setup_for_episode(episode) # Reset all Nics for nic in self.nics.values(): - nic.reset_component_for_episode(episode) + nic.setup_for_episode(episode) for software in self.software_manager.software.values(): - software.reset_component_for_episode(episode) + software.setup_for_episode(episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index aa154ad9..887bc9be 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -743,16 +743,16 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() - self.acl.reset_component_for_episode(episode) - self.route_table.reset_component_for_episode(episode) + self.acl.setup_for_episode(episode) + self.route_table.setup_for_episode(episode) for i, nic in self.ethernet_ports.items(): - nic.reset_component_for_episode(episode) + nic.setup_for_episode(episode) self.enable_port(i) - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 18ed894c..bb6132a8 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,9 +21,9 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.reset_component_for_episode(episode) + self.network.setup_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() From 52677538a89f9f5c5ffa88b72e0b0e0415f85cc6 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 23 Feb 2024 15:12:46 +0000 Subject: [PATCH 06/18] #2238 - Tidied up code, added more docstrings, and implemented suggestions from PR. --- .../network/network_interfaces.rst | 2 +- src/primaite/game/agent/observations.py | 4 +--- .../simulator/network/hardware/base.py | 18 ++++++++++++------ .../network/hardware/nodes/host/host_node.py | 6 +----- src/primaite/simulator/network/nmne.py | 1 + 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/docs/source/simulation_components/network/network_interfaces.rst b/docs/source/simulation_components/network/network_interfaces.rst index c74b54ae..2bb8dda4 100644 --- a/docs/source/simulation_components/network/network_interfaces.rst +++ b/docs/source/simulation_components/network/network_interfaces.rst @@ -71,7 +71,7 @@ Network Interface Classes - Malicious Network Events Monitoring: * Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns. - * Integrates NMNE detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies. + * Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies. * Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms. **WiredNetworkInterface (Connection Type Layer)** diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 1d8799fd..7ccc3f11 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -352,8 +352,6 @@ class NicObservation(AbstractObservation): """The default NIC observation dict.""" data = {"nic_status": 0} - if CAPTURE_NMNE: - data.update({"nmne": {"inbound": 0, "outbound": 0}}) return data def __init__(self, where: Optional[Tuple[str]] = None) -> None: @@ -407,7 +405,7 @@ class NicObservation(AbstractObservation): return self.default_observation else: obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2, "nmne": {}} - if CAPTURE_NMNE: + if CAPTURE_NMNE and nic_state.get("nmne"): direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) inbound_count = inbound_keywords.get("*", 0) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index c0e69e60..b22bea25 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -98,6 +98,7 @@ class NetworkInterface(SimComponent, ABC): "A PacketCapture instance for capturing and analysing packets passing through this interface." nmne: Dict = Field(default_factory=lambda: {}) + "A dict containing details of the number of malicious network events captured." def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -122,7 +123,6 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) - state.update({"nmne": self.nmne}) return state def reset_component_for_episode(self, episode: int): @@ -134,23 +134,29 @@ class NetworkInterface(SimComponent, ABC): self.pcap.setup_logger() self.enable() - # @abstractmethod + @abstractmethod def enable(self): """Enable the interface.""" pass - # @abstractmethod + @abstractmethod def disable(self): """Disable the interface.""" pass - def _capture_nmne(self, frame: Frame, inbound: bool = True): + def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None: """ Processes and captures network frame data based on predefined global NMNE settings. This method updates the NMNE structure with counts of malicious network events based on the frame content and direction. The structure is dynamically adjusted according to the enabled capture settings. + .. note:: + While there is a lot of logic in this code that defines a multi-level hierarchical NMNE structure, + most of it is unused for now as a result of all `CAPTURE_BY_<>` variables in + ``primaite.simulator.network.nmne`` being hardcoded and set as final. Once they're 'released' and made + configurable, this function will be updated to properly explain the dynamic data structure. + :param frame: The network frame to process, containing IP, TCP/UDP, and payload information. :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. """ @@ -214,7 +220,7 @@ class NetworkInterface(SimComponent, ABC): # Increment a generic counter if keyword capturing is not enabled keyword_level["*"] = keyword_level.get("*", 0) + 1 - # @abstractmethod + @abstractmethod def send_frame(self, frame: Frame) -> bool: """ Attempts to send a network frame through the interface. @@ -224,7 +230,7 @@ class NetworkInterface(SimComponent, ABC): """ self._capture_nmne(frame, inbound=False) - # @abstractmethod + @abstractmethod def receive_frame(self, frame: Frame) -> bool: """ Receives a network frame on the interface. diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 6ecd6733..8e104924 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -205,11 +205,7 @@ class NIC(IPWiredNetworkInterface): state = super().describe_state() # Update the state with NIC-specific information - state.update( - { - "wake_on_lan": self.wake_on_lan, - } - ) + state.update({"wake_on_lan": self.wake_on_lan, "nmne": self.nmne}) return state diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index d4c40631..87839712 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -6,6 +6,7 @@ CAPTURE_NMNE: bool = True NMNE_CAPTURE_KEYWORDS: List[str] = [] """List of keywords to identify malicious network events.""" +# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically CAPTURE_BY_DIRECTION: Final[bool] = True """Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" CAPTURE_BY_IP_ADDRESS: Final[bool] = False From c115095157f27d6d7480df430c2c83e50078184d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 25 Feb 2024 16:17:12 +0000 Subject: [PATCH 07/18] Fix router from config using wrong method --- src/primaite/simulator/network/hardware/nodes/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 887bc9be..6bf80b3c 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -1016,7 +1016,7 @@ class Router(Node): # Add the router's default ACL rules from the config. if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): - new.add_rule( + new.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else Port[p], dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], From 994dbc3501b7c50584322cf3bba9db7aa7e0b77d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 25 Feb 2024 17:44:41 +0000 Subject: [PATCH 08/18] Finalise the refactor. It works well now. --- .../config/_package_data/example_config.yaml | 5 +- src/primaite/game/game.py | 12 +++- src/primaite/notebooks/uc2_demo.ipynb | 66 +++++++++---------- src/primaite/session/environment.py | 7 +- src/primaite/simulator/network/container.py | 6 +- .../simulator/network/hardware/base.py | 10 +-- .../network/hardware/nodes/network/router.py | 2 +- src/primaite/simulator/sim_container.py | 2 +- .../system/services/web_server/web_server.py | 2 +- src/primaite/simulator/system/software.py | 7 +- 10 files changed, 65 insertions(+), 54 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f85baf10..a32696c7 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -652,12 +652,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 02d36c8a..f5649589 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -185,6 +185,10 @@ class PrimaiteGame: """Close the game, this will close the simulation.""" return NotImplemented + def setup_for_episode(self, episode: int) -> None: + """Perform any final configuration of components to make them ready for the game to start.""" + self.simulation.setup_for_episode(episode=episode) + @classmethod def from_config(cls, cfg: Dict) -> "PrimaiteGame": """Create a PrimaiteGame object from a config dictionary. @@ -258,7 +262,9 @@ class PrimaiteGame: new_service = new_node.software_manager.software[service_type] game.ref_map_services[service_ref] = new_service.uuid else: - _LOGGER.warning(f"service type not found {service_type}") + msg = f"Configuration contains an invalid service type: {service_type}" + _LOGGER.error(msg) + raise ValueError(msg) # service-dependent options if service_type == "DNSClient": if "options" in service_cfg: @@ -297,7 +303,9 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: - _LOGGER.warning(f"application type not found {application_type}") + msg = f"Configuration contains an invalid application type: {application_type}" + _LOGGER.error(msg) + raise ValueError(msg) if application_type == "DataManipulationBot": if "options" in application_cfg: diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index c4fe4c9a..460e3d27 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -335,9 +335,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", @@ -347,9 +345,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# Imports\n", @@ -372,9 +368,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# create the env\n", @@ -385,10 +379,10 @@ " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", - "game = PrimaiteGame.from_config(cfg)\n", - "env = PrimaiteGymEnv(game = game)\n", - "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", - "env.agent.flatten_obs = False\n", + " # don't flatten observations so that we can see what is going on\n", + " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", "obs, info = env.reset()\n", "print('env created successfully')\n", "pprint(obs)" @@ -422,9 +416,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "for step in range(35):\n", @@ -442,9 +434,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "pprint(obs['NODES'])" @@ -460,9 +450,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", @@ -488,9 +476,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", @@ -515,9 +501,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", @@ -540,9 +524,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "env.step(13) # Patch the database\n", @@ -582,6 +564,22 @@ "obs['ACL']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reset the cell, you can rerun the other cells to verify that the attack works the same every episode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.reset()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -592,7 +590,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -606,9 +604,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 2 } diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index ad770f8f..bab81253 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -74,6 +74,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" ) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) self.agent = self.game.rl_agents[0] self.episode_counter += 1 state = self.game.get_sim_state() @@ -97,12 +98,12 @@ class PrimaiteGymEnv(gymnasium.Env): def _get_obs(self) -> ObsType: """Return the current observation.""" - if not self.agent.flatten_obs: - return self.agent.observation_manager.current_observation - else: + if self.agent.flatten_obs: unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + else: + return self.agent.observation_manager.current_observation class PrimaiteRayEnv(gymnasium.Env): diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index c3ad84c3..b5a16430 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -48,9 +48,9 @@ class Network(SimComponent): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.setup_for_episode(episode) + node.setup_for_episode(episode=episode) for link in self.links.values(): - link.setup_for_episode(episode) + link.setup_for_episode(episode=episode) for node in self.nodes.values(): node.power_on() @@ -172,7 +172,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.setup_for_episode() + link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here. def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 771c3397..1b6d611e 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -90,7 +90,7 @@ class NetworkInterface(SimComponent, ABC): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().setup_for_episode(episode) + super().setup_for_episode(episode=episode) if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() @@ -643,17 +643,17 @@ class Node(SimComponent): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().setup_for_episode(episode) + super().setup_for_episode(episode=episode) # Reset File System - self.file_system.setup_for_episode(episode) + self.file_system.setup_for_episode(episode=episode) # Reset all Nics for network_interface in self.network_interfaces.values(): - network_interface.setup_for_episode(episode) + network_interface.setup_for_episode(episode=episode) for software in self.software_manager.software.values(): - software.setup_for_episode(episode) + software.setup_for_episode(episode=episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index c299dfb7..3111a153 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1078,7 +1078,7 @@ class Router(NetworkNode): for i, _ in self.network_interface.items(): self.enable_port(i) - super().setup_for_episode(episode) + super().setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index bb6132a8..a2285d92 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -23,7 +23,7 @@ class Simulation(SimComponent): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.setup_for_episode(episode) + self.network.setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5e4a6f6e..ce29a2f9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -118,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: + except Exception: # TODO: refactor this. Likely to cause silent bugs. # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 56a1e3d1..8864659c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -3,7 +3,7 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +if TYPE_CHECKING: + from primaite.simulator.system.core.software_manager import SoftwareManager + class SoftwareType(Enum): """ @@ -84,7 +87,7 @@ class Software(SimComponent): "The count of times the software has been scanned, defaults to 0." revealed_to_red: bool = False "Indicates if the software has been revealed to red agent, defaults is False." - software_manager: Any = None + software_manager: "SoftwareManager" = None "An instance of Software Manager that is used by the parent node." sys_log: SysLog = None "An instance of SysLog that is used by the parent node." From 63c9a36c30adf38716759526968067ae990f1fdc Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 25 Feb 2024 18:36:20 +0000 Subject: [PATCH 09/18] Fix typos --- src/primaite/notebooks/uc2_demo.ipynb | 2 +- src/primaite/simulator/system/services/ntp/ntp_client.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 460e3d27..7c90a885 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -568,7 +568,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Reset the cell, you can rerun the other cells to verify that the attack works the same every episode." + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode." ] }, { diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 1e9dc139..ad00065c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -1,6 +1,6 @@ from datetime import datetime from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket @@ -54,7 +54,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: List[Port] = Port.NTP, + dest_port: Port = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. From e5982c4599b07ef5cf994218f4323d1105f65bc7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 10:26:28 +0000 Subject: [PATCH 10/18] Change agents list in game object to dictionary --- .../example_config_2_rl_agents.yaml | 446 +++++++++++------- src/primaite/game/game.py | 18 +- .../training_example_ray_multi_agent.ipynb | 9 +- .../training_example_ray_single_agent.ipynb | 2 +- .../notebooks/training_example_sb3.ipynb | 11 +- src/primaite/session/environment.py | 52 +- tests/conftest.py | 2 +- tests/integration_tests/game_configuration.py | 16 +- 8 files changed, 331 insertions(+), 225 deletions(-) diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 93019c9d..1ccd7b38 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -10,6 +10,8 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true game: @@ -36,9 +38,9 @@ agents: - type: NODE_APPLICATION_EXECUTE options: nodes: - - node_ref: client_2 + - node_name: client_2 applications: - - application_ref: client_2_web_browser + - application_name: WebBrowser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -54,6 +56,31 @@ agents: frequency: 4 variance: 3 + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + reward_function: + reward_components: + - type: DUMMY + + + + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -72,7 +99,7 @@ agents: - type: NODE_OS_SCAN options: nodes: - - node_ref: client_1 + - node_name: client_1 applications: - application_name: DataManipulationBot - node_name: client_2 @@ -104,25 +131,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -137,23 +160,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -184,10 +207,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -242,25 +265,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -271,22 +294,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -303,63 +326,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -407,123 +430,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... - + flatten_obs: true - ref: defender_2 team: BLUE @@ -537,25 +585,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -570,23 +614,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -617,10 +661,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -675,25 +719,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -704,22 +748,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -736,63 +780,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -840,122 +884,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... + flatten_obs: true @@ -1032,12 +1102,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -1089,10 +1160,14 @@ simulation: - ref: data_manipulation_bot type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ services: - ref: client_1_dns_client type: DNSClient @@ -1109,6 +1184,13 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f5649589..8edf70ea 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -79,11 +79,11 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - self.agents: List[AbstractAgent] = [] - """List of agents.""" + self.agents: Dict[str, AbstractAgent] = {} + """Mapping from agent name to agent object.""" - self.rl_agents: List[ProxyAgent] = [] - """Subset of agent list including only the reinforcement learning agents.""" + self.rl_agents: Dict[str, ProxyAgent] = {} + """Subset of agents which are intended for reinforcement learning.""" self.step_counter: int = 0 """Current timestep within the episode.""" @@ -144,7 +144,7 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for agent in self.agents: + for name, agent in self.agents.items(): agent.update_observation(state) agent.update_reward(state) agent.reward_function.total_reward += agent.reward_function.current_reward @@ -158,7 +158,7 @@ class PrimaiteGame: """ agent_actions = {} - for agent in self.agents: + for name, agent in self.agents.items(): obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) @@ -396,7 +396,6 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -405,8 +404,7 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) - game.rl_agents.append(new_agent) + game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], @@ -415,8 +413,8 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) else: _LOGGER.warning(f"agent type {agent_type} not found") + game.agents[agent_cfg["ref"]] = new_agent return game diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index 0d4b6d0e..4ef02443 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -60,7 +60,7 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", " .rollouts(num_rollout_workers=0)\n", " .training(train_batch_size=128)\n", " )\n" @@ -88,6 +88,13 @@ " param_space=config\n", ").fit()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index ea006ae9..3c27bdc6 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -54,7 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "env_config = {\"cfg\":cfg}\n", + "env_config = cfg\n", "\n", "config = (\n", " PPOConfig()\n", diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index 164142b2..0472854e 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -27,9 +27,7 @@ "outputs": [], "source": [ "with open(example_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - "game = PrimaiteGame.from_config(cfg)" + " cfg = yaml.safe_load(f)\n" ] }, { @@ -76,6 +74,13 @@ "source": [ "model.save(\"deleteme\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index bab81253..f8dbab9d 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType @@ -25,12 +25,17 @@ class PrimaiteGymEnv(gymnasium.Env): """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) """Current game.""" - self.agent: ProxyAgent = self.game.rl_agents[0] - """The agent within the game that is controlled by the RL algorithm.""" + self._agent_name = next(iter(self.game.rl_agents)) + """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" self.episode_counter: int = 0 """Current episode number.""" + @property + def agent(self) -> ProxyAgent: + """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" + return self.game.rl_agents[self._agent_name] + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" # make ProxyAgent store the action chosen my the RL policy @@ -71,11 +76,10 @@ class PrimaiteGymEnv(gymnasium.Env): """Reset the environment.""" print( f"Resetting environment, episode {self.episode_counter}, " - f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" + f"avg. reward: {self.agent.reward_function.total_reward}" ) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) self.game.setup_for_episode(episode=self.episode_counter) - self.agent = self.game.rl_agents[0] self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) @@ -112,11 +116,10 @@ class PrimaiteRayEnv(gymnasium.Env): def __init__(self, env_config: Dict) -> None: """Initialise the environment. - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict """ - self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) + self.env = PrimaiteGymEnv(game_config=env_config) self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -138,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :type env_config: Dict """ - self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"]) + self.game_config: Dict = env_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) """Reference to the primaite game""" - self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} - """List of all possible agents in the environment. This list should not change!""" - self._agent_ids = list(self.agents.keys()) + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. This is a list of strings of agent names.""" + self.episode_counter: int = 0 + """Current episode number.""" self.terminateds = set() self.truncateds = set() @@ -159,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) super().__init__() + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -182,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -196,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {"agent_actions": agent_actions} + infos = {name: {} for name, _ in self.agents.items()} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: @@ -222,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} - for name, agent in self.agents.items(): + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) return obs diff --git a/tests/conftest.py b/tests/conftest.py index 5084c339..83ac9559 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -510,6 +510,6 @@ def game_and_agent(): reward_function=reward_function, ) - game.agents.append(test_agent) + game.agents["test_agent"] = test_agent return (game, test_agent) diff --git a/tests/integration_tests/game_configuration.py b/tests/integration_tests/game_configuration.py index 3bd870e3..f3dc51bd 100644 --- a/tests/integration_tests/game_configuration.py +++ b/tests/integration_tests/game_configuration.py @@ -42,20 +42,20 @@ def test_example_config(): assert len(game.agents) == 4 # red, blue and 2 green agents # green agent 1 - assert game.agents[0].agent_name == "client_2_green_user" - assert isinstance(game.agents[0], RandomAgent) + assert "client_2_green_user" in game.agents + assert isinstance(game.agents["client_2_green_user"], RandomAgent) # green agent 2 - assert game.agents[1].agent_name == "client_1_green_user" - assert isinstance(game.agents[1], RandomAgent) + assert "client_1_green_user" in game.agents + assert isinstance(game.agents["client_1_green_user"], RandomAgent) # red agent - assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot" - assert isinstance(game.agents[2], DataManipulationAgent) + assert "client_1_data_manipulation_red_bot" in game.agents + assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) # blue agent - assert game.agents[3].agent_name == "defender" - assert isinstance(game.agents[3], ProxyAgent) + assert "defender" in game.agents + assert isinstance(game.agents["defender"], ProxyAgent) network: Network = game.simulation.network From ccb10f1160cee454216c205e15c08d2ecc0c02d7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 11:02:37 +0000 Subject: [PATCH 11/18] Update docs based on reset refactor --- CHANGELOG.md | 1 + docs/index.rst | 1 + docs/source/environment.rst | 10 ++++++++++ docs/source/game_layer.rst | 5 +++++ docs/source/primaite_session.rst | 5 +++++ 5 files changed, 22 insertions(+) create mode 100644 docs/source/environment.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index 01e45d2e..d2a582be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Made environment reset completely recreate the game object. - Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack. - Changed the data manipulation scenario to include a second green agent on client 1. - Refactored actions and observations to be configurable via object name, instead of UUID. diff --git a/docs/index.rst b/docs/index.rst index 9eae8adc..08e0ac21 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/simulation source/game_layer source/config + source/environment .. toctree:: :caption: Developer information: diff --git a/docs/source/environment.rst b/docs/source/environment.rst new file mode 100644 index 00000000..87e7f060 --- /dev/null +++ b/docs/source/environment.rst @@ -0,0 +1,10 @@ +RL Environments +*************** + +RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs: + +* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation `_. +* Ray Single agent API - For training a single Ray RLLib agent +* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. + +There is a Jupyter notebook which demonstrates integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index cdae17dd..1f2921fe 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface PrimAITE Session ^^^^^^^^^^^^^^^ +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + ``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types. Agents diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 706397b6..87a3f03d 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -4,6 +4,11 @@ .. _run a primaite session: +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + Run a PrimAITE Session ====================== From a5043a8fbe01b38699cedf677129a17bec32655d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 12:15:53 +0000 Subject: [PATCH 12/18] Modify tests based on refactoring --- src/primaite/session/session.py | 6 ++-- src/primaite/simulator/network/airspace.py | 5 --- .../network/hardware/nodes/host/host_node.py | 5 --- .../hardware/nodes/network/firewall.py | 18 ---------- .../network/hardware/nodes/network/switch.py | 6 ---- .../hardware/nodes/network/wireless_router.py | 3 -- .../assets/configs/bad_primaite_session.yaml | 7 ++-- .../configs/eval_only_primaite_session.yaml | 9 ++--- tests/assets/configs/multi_agent_session.yaml | 10 +++--- .../assets/configs/test_primaite_session.yaml | 9 ++--- .../configs/train_only_primaite_session.yaml | 9 ++--- .../environments/test_sb3_environment.py | 3 +- .../_simulator/_domain/test_account.py | 19 ----------- .../_file_system/test_file_system.py | 31 ----------------- .../_simulator/_network/test_container.py | 34 ------------------- .../_red_applications/test_dos_bot.py | 28 --------------- 16 files changed, 28 insertions(+), 174 deletions(-) diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 5c663cfd..b8f80e95 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -101,11 +101,11 @@ class PrimaiteSession: # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayEnv(env_config=cfg) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayMARLEnv(env_config=cfg) elif sess.training_options.rl_framework == "SB3": - sess.env = PrimaiteGymEnv(game=game) + sess.env = PrimaiteGymEnv(game_config=cfg) sess.policy = PolicyABC.from_config(sess.training_options, session=sess) if agent_load_path: diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 724b8728..d264f751 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def enable(self): """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 3f34f736..329a5fa0 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -213,11 +213,6 @@ class NIC(IPWiredNetworkInterface): return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def receive_frame(self, frame: Frame) -> bool: """ Attempt to receive and process a network frame from the connected Link. diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 22effa2a..f2305652 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -109,24 +109,6 @@ class Firewall(Router): sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound" ) - self.set_original_state() - - def set_original_state(self): - """Set the original state for the Firewall.""" - super().set_original_state() - vals_to_include = { - "internal_port", - "external_port", - "dmz_port", - "internal_inbound_acl", - "internal_outbound_acl", - "dmz_inbound_acl", - "dmz_outbound_acl", - "external_inbound_acl", - "external_outbound_acl", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def describe_state(self) -> Dict: """ Describes the current state of the Firewall. diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 33e6ee9a..557ea287 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface): _connected_node: Optional[Switch] = None "The Switch to which the SwitchPort is connected." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index dd0b58d3..91833d6a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -122,8 +122,6 @@ class WirelessRouter(Router): self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) - self.set_original_state() - @property def wireless_access_point(self) -> WirelessAccessPoint: """ @@ -166,7 +164,6 @@ class WirelessRouter(Router): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured WAP {network_interface}") - self.set_original_state() self.wireless_access_point.frequency = frequency # Set operating frequency self.wireless_access_point.enable() # Re-enable the WAP with new settings diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 5bdc3273..c76aeef6 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -589,15 +589,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 8361e318..1cb59f87 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -593,15 +593,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -624,7 +625,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 87bd9d1c..b1b15372 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1043,16 +1043,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - - ref: database_server type: server @@ -1074,7 +1074,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 76190a64..e5f9d544 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -599,15 +599,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -630,7 +631,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5d004c7e..10e088d8 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -600,15 +600,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -631,7 +632,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 91cf5c1e..dc5d10e9 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -17,8 +17,7 @@ def test_sb3_compatibility(): with open(example_config_path(), "r") as f: cfg = yaml.safe_load(f) - game = PrimaiteGame.from_config(cfg) - gym = PrimaiteGymEnv(game=game) + gym = PrimaiteGymEnv(game_config=cfg) model = PPO("MlpPolicy", gym) model.learn(total_timesteps=1000) diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 695b15dd..786fe851 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -12,20 +12,6 @@ def account() -> Account: def test_original_state(account): """Test the original state - see if it resets properly""" - account.log_on() - account.log_off() - account.disable() - - state = account.describe_state() - assert state["num_logons"] is 1 - assert state["num_logoffs"] is 1 - assert state["num_group_changes"] is 0 - assert state["username"] is "Jake" - assert state["password"] is "totally_hashed_password" - assert state["account_type"] is AccountType.USER.value - assert state["enabled"] is False - - account.reset_component_for_episode(episode=1) state = account.describe_state() assert state["num_logons"] is 0 assert state["num_logoffs"] is 0 @@ -39,11 +25,6 @@ def test_original_state(account): account.log_off() account.disable() - account.log_on() - state = account.describe_state() - assert state["num_logons"] is 2 - - account.reset_component_for_episode(episode=2) state = account.describe_state() assert state["num_logons"] is 1 assert state["num_logoffs"] is 1 diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 2fe3f04c..4defc80c 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -185,37 +185,6 @@ def test_get_file(file_system): file_system.show(full=True) -def test_reset_file_system(file_system): - # file and folder that existed originally - file_system.create_file(file_name="test_file.zip") - file_system.create_folder(folder_name="test_folder") - - # create a new file - file_system.create_file(file_name="new_file.txt") - - # create a new folder - file_system.create_folder(folder_name="new_folder") - - # delete the file that existed originally - file_system.delete_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None - - # delete the folder that existed originally - file_system.delete_folder(folder_name="test_folder") - assert file_system.get_folder(folder_name="test_folder") is None - - # reset - file_system.reset_component_for_episode(episode=1) - - # deleted original file and folder should be back - assert file_system.get_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_folder(folder_name="test_folder") - - # new file and folder should be removed - assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None - assert file_system.get_folder(folder_name="new_folder") is None - - @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index bf79677e..2cfc3f11 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -44,40 +44,6 @@ def test_describe_state(network): assert len(state["links"]) is 6 -def test_reset_network(network): - """ - Test that the network is properly reset. - - TODO: make sure that once implemented - any installed/uninstalled services, processes, apps, - etc are also removed/reinstalled - - """ - state_before = network.describe_state() - - client_1: Computer = network.get_node_by_hostname("client_1") - server_1: Computer = network.get_node_by_hostname("server_1") - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - - client_1.power_off() - assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - server_1.power_off() - assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - assert network.describe_state() != state_before - - network.reset_component_for_episode(episode=1) - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - # don't worry if UUIDs change - a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) - b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) - assert a == b - - def test_creating_container(): """Check that we can create a network container""" net = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 1f28244d..4bfd28d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -27,34 +27,6 @@ def test_dos_bot_creation(dos_bot): assert dos_bot is not None -def test_dos_bot_reset(dos_bot): - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - - # should reset the relevant items - dos_bot.reset_component_for_episode(episode=0) - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - dos_bot.reset_component_for_episode(episode=1) - # should reset to the configured value - assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") - assert dos_bot.target_port is Port.HTTP - assert dos_bot.payload == "payload" - assert dos_bot.repeat is True - - def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node: Computer = dos_bot.parent assert dos_bot_node.operating_state is NodeOperatingState.ON From 2076b011ba8837f8e85e362a68c71b1010aa8e0a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 14:26:47 +0000 Subject: [PATCH 13/18] Put back default router rules --- src/primaite/simulator/network/hardware/nodes/network/router.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 3111a153..52f38eb6 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1039,6 +1039,8 @@ class Router(NetworkNode): self.connect_nic(network_interface) self.network_interface[i] = network_interface + self._set_default_acl() + def _install_system_software(self): """ Installs essential system software and network services on the router. From f9cc5af7aab3d822dcc23472f65c74a04ced4650 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 16:06:58 +0000 Subject: [PATCH 14/18] Not sure how this test was passing before --- .../_primaite/_simulator/_system/test_software.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index e77cd895..6f680012 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,12 +2,14 @@ from typing import Dict import pytest +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.software import Software, SoftwareHealthState +from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import IOSoftware, SoftwareHealthState -class TestSoftware(Software): +class TestSoftware(Service): def describe_state(self) -> Dict: pass @@ -15,7 +17,11 @@ class TestSoftware(Software): @pytest.fixture(scope="function") def software(file_system): return TestSoftware( - name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + name="TestSoftware", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="test_service"), + protocol=IPProtocol.TCP, ) From 33d2ecc26a4ac125607477c0a2846afa1b6fc728 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 16:58:43 +0000 Subject: [PATCH 15/18] Apply suggestions from code review. --- docs/source/environment.rst | 2 +- .../config/_package_data/example_config_2_rl_agents.yaml | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/source/environment.rst b/docs/source/environment.rst index 87e7f060..2b76572d 100644 --- a/docs/source/environment.rst +++ b/docs/source/environment.rst @@ -7,4 +7,4 @@ RL environments are the objects that directly interface with RL libraries such a * Ray Single agent API - For training a single Ray RLLib agent * Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. -There is a Jupyter notebook which demonstrates integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. +There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 1ccd7b38..c1e077be 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -1,11 +1,17 @@ training_config: rl_framework: RLLIB_multi_agent - # rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 256 + deterministic_eval: false n_agents: 2 agent_references: - defender_1 - defender_2 + io_settings: save_checkpoints: true checkpoint_interval: 5 From d55b6a5b48bf0faa6aeed6bd5ee94c65ab90912b Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Feb 2024 12:03:58 +0000 Subject: [PATCH 16/18] #2238 - Fixed the observations issue causing tests to fail --- src/primaite/game/agent/observations.py | 7 +++++-- src/primaite/simulator/network/hardware/base.py | 2 ++ .../simulator/network/hardware/nodes/host/host_node.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 7ccc3f11..82e11fe0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -351,6 +351,8 @@ class NicObservation(AbstractObservation): def default_observation(self) -> Dict: """The default NIC observation dict.""" data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"nmne": {"inbound": 0, "outbound": 0}}) return data @@ -404,8 +406,9 @@ class NicObservation(AbstractObservation): if nic_state is NOT_PRESENT_IN_STATE: return self.default_observation else: - obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2, "nmne": {}} - if CAPTURE_NMNE and nic_state.get("nmne"): + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"nmne": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) inbound_count = inbound_keywords.get("*", 0) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index b22bea25..35c90d05 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,6 +123,8 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) + if CAPTURE_NMNE: + state.update({"nmne": self.nmne}) return state def reset_component_for_episode(self, episode: int): diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 8e104924..b48950b7 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -205,7 +205,7 @@ class NIC(IPWiredNetworkInterface): state = super().describe_state() # Update the state with NIC-specific information - state.update({"wake_on_lan": self.wake_on_lan, "nmne": self.nmne}) + state.update({"wake_on_lan": self.wake_on_lan}) return state From 63ea5478ab4fc39ffde3359983c332098fd318b6 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Feb 2024 13:56:19 +0000 Subject: [PATCH 17/18] #2238 - Updated uc2_demo.ipynb to explain the NMNE in observation space --- src/primaite/notebooks/uc2_demo.ipynb | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index c4fe4c9a..b1e12370 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -130,6 +130,9 @@ " - NETWORK_INTERFACES\n", " - \n", " - nic_status\n", + " - nmne\n", + " - inbound\n", + " - outbound\n", " - operating_status\n", "- LINKS\n", " - \n", @@ -220,6 +223,14 @@ "|1|ENABLED|\n", "|2|DISABLED|\n", "\n", + "NMNE (number of malicious network events) means, for inbound or outbound traffic, means:\n", + "|value|NMNEs|\n", + "|--|--|\n", + "|0|None|\n", + "|1|1 - 5|\n", + "|2|6 - 10|\n", + "|3|More than 10|\n", + "\n", "Link load has the following meaning:\n", "|load|percent utilisation|\n", "|--|--|\n", From 8730330f73bf5b38ade30bfc18c23ee3c5523367 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Feb 2024 10:14:31 +0000 Subject: [PATCH 18/18] Apply PR suggestions --- src/primaite/config/_package_data/example_config.yaml | 2 +- .../_package_data/example_config_2_rl_agents.yaml | 2 +- src/primaite/game/game.py | 10 ++++++---- .../simulator/network/hardware/nodes/network/router.py | 1 - .../simulator/system/services/web_server/web_server.py | 2 +- .../environments/test_sb3_environment.py | 1 - 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index a32696c7..ebee4980 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -14,7 +14,7 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false - save_pcap_logs: true + save_pcap_logs: false save_sys_logs: true diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index c1e077be..992c3a1a 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -16,7 +16,7 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false - save_pcap_logs: true + save_pcap_logs: false save_sys_logs: true diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8edf70ea..baa84d1d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -133,7 +133,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - agent_actions = self.apply_agent_actions() # noqa + self.apply_agent_actions() # Advance timestep self.advance_timestep() @@ -144,7 +144,7 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for name, agent in self.agents.items(): + for _, agent in self.agents.items(): agent.update_observation(state) agent.update_reward(state) agent.reward_function.total_reward += agent.reward_function.current_reward @@ -158,7 +158,7 @@ class PrimaiteGame: """ agent_actions = {} - for name, agent in self.agents.items(): + for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) @@ -414,7 +414,9 @@ class PrimaiteGame: agent_settings=agent_settings, ) else: - _LOGGER.warning(f"agent type {agent_type} not found") + msg(f"Configuration error: {agent_type} is not a valid agent type.") + _LOGGER.error(msg) + raise ValueError(msg) game.agents[agent_cfg["ref"]] = new_agent return game diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 52f38eb6..aa6eec3a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1076,7 +1076,6 @@ class Router(NetworkNode): :param episode: The episode number for which the router is being reset. """ self.software_manager.arp.clear() - # self.acl.reset_component_for_episode(episode) for i, _ in self.network_interface.items(): self.enable_port(i) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index ce29a2f9..5e7591e9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -118,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: # TODO: refactor this. Likely to cause silent bugs. + except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index dc5d10e9..c48ddbc9 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -11,7 +11,6 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv -# @pytest.mark.skip(reason="no way of currently testing this") def test_sb3_compatibility(): """Test that the Gymnasium environment can be used with an SB3 agent.""" with open(example_config_path(), "r") as f: