From 64b9ba3ecf2e1865e902917bd80d05bac70ab0bd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:21:03 +0000 Subject: [PATCH 01/14] Make environment reset reinstantiate the game --- src/primaite/game/game.py | 5 +--- .../notebooks/training_example_sb3.ipynb | 4 +-- src/primaite/session/environment.py | 27 ++++++++++++------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1fd0dc8b..091438ce 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -67,9 +67,6 @@ class PrimaiteGame: self.step_counter: int = 0 """Current timestep within the episode.""" - self.episode_counter: int = 0 - """Current episode number.""" - self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" @@ -163,7 +160,7 @@ class PrimaiteGame: return True return False - def reset(self) -> None: + def reset(self) -> None: # TODO: deprecated - remove me """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e5085c5e..164142b2 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -38,7 +38,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game)" + "gym = PrimaiteGymEnv(game_config=cfg)" ] }, { @@ -65,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=1000)\n" + "model.learn(total_timesteps=10)\n" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a3831bc1..ad770f8f 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -18,11 +18,18 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame): + def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() - self.game: "PrimaiteGame" = game + self.game_config: Dict = game_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + """Current game.""" self.agent: ProxyAgent = self.game.rl_agents[0] + """The agent within the game that is controlled by the RL algorithm.""" + + self.episode_counter: int = 0 + """Current episode number.""" def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -45,13 +52,13 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), @@ -63,10 +70,12 @@ class PrimaiteGymEnv(gymnasium.Env): def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" print( - f"Resetting environment, episode {self.game.episode_counter}, " + f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" ) - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.agent = self.game.rl_agents[0] + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -107,7 +116,7 @@ class PrimaiteRayEnv(gymnasium.Env): :type env_config: Dict[str, PrimaiteGame] """ self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) - self.env.game.episode_counter -= 1 + self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -194,13 +203,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards, From f82506023bcd978ff28e28a25491eb6b6facdee7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:29:27 +0000 Subject: [PATCH 02/14] Delete set_original_state method definitions --- src/primaite/game/game.py | 2 - src/primaite/simulator/core.py | 10 +--- src/primaite/simulator/domain/account.py | 13 ----- src/primaite/simulator/file_system/file.py | 9 ---- .../simulator/file_system/file_system.py | 11 ---- .../file_system/file_system_item_abc.py | 5 -- src/primaite/simulator/file_system/folder.py | 17 ------- src/primaite/simulator/network/container.py | 7 --- .../simulator/network/hardware/base.py | 50 ------------------- .../network/hardware/nodes/router.py | 48 ------------------ src/primaite/simulator/sim_container.py | 4 -- .../system/applications/application.py | 6 --- .../system/applications/database_client.py | 8 --- .../red_applications/data_manipulation_bot.py | 15 ------ .../applications/red_applications/dos_bot.py | 16 ------ .../system/applications/web_browser.py | 8 --- .../simulator/system/processes/process.py | 6 --- .../services/database/database_service.py | 13 ----- .../system/services/dns/dns_client.py | 7 --- .../system/services/dns/dns_server.py | 7 --- .../system/services/ftp/ftp_client.py | 7 --- .../system/services/ftp/ftp_server.py | 7 --- .../simulator/system/services/service.py | 6 --- .../system/services/web_server/web_server.py | 7 --- src/primaite/simulator/system/software.py | 19 ------- .../_simulator/_domain/test_account.py | 2 - .../_file_system/test_file_system.py | 1 - .../_simulator/_network/test_container.py | 1 - .../_red_applications/test_dos_bot.py | 2 - 29 files changed, 1 insertion(+), 313 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 091438ce..bd7ed2cd 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -386,6 +386,4 @@ class PrimaiteGame: else: _LOGGER.warning(f"agent type {agent_type} not found") - game.simulation.set_original_state() - return game diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 98a7e8db..e21ce9eb 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -153,8 +153,6 @@ class SimComponent(BaseModel): uuid: str """The component UUID.""" - _original_state: Dict = {} - def __init__(self, **kwargs): if not kwargs.get("uuid"): kwargs["uuid"] = str(uuid4()) @@ -162,15 +160,9 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - # @abstractmethod - def set_original_state(self): - """Sets the original state.""" - pass - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - for key, value in self._original_state.items(): - self.__setattr__(key, value) + pass def _init_request_manager(self) -> RequestManager: """ diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d9dad06a..186caf5b 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,19 +42,6 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "num_logons", - "num_logoffs", - "num_group_changes", - "username", - "password", - "account_type", - "enabled", - } - self._original_state = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 608a1d78..4cd5cdbb 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,15 +73,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}") - super().set_original_state() - vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ee80587d..a7252a2d 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -34,17 +34,6 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}") - for folder in self.folders.values(): - folder.set_original_state() - # Capture a list of all 'original' file uuids - original_keys = list(self.folders.keys()) - vals_to_include = {"sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_folder_uuids"] = original_keys - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index c3e1426b..fbe5f4b3 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} - self._original_state = self.model_dump(include=vals_to_keep) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 13fdc597..39c3dad8 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -49,23 +49,6 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}") - for file in self.files.values(): - file.set_original_state() - super().set_original_state() - vals_to_include = { - "scan_duration", - "scan_countdown", - "red_scan_duration", - "red_scan_countdown", - "restore_duration", - "restore_countdown", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_file_uuids"] = list(self.files.keys()) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 8989a60f..48205bbd 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,13 +45,6 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def set_original_state(self): - """Sets the original state.""" - for node in self.nodes.values(): - node.set_original_state() - for link in self.links.values(): - link.set_original_state() - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 01dd736d..68f3816d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,13 +123,6 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" super().reset_component_for_episode(episode) @@ -349,14 +342,6 @@ class SwitchPort(SimComponent): kwargs["mac_address"] = generate_mac_address() super().__init__(**kwargs) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -506,14 +491,6 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"bandwidth", "current_load"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -1033,33 +1010,6 @@ class Node(SimComponent): self.arp.nics = self.nics self.session_manager.software_manager = self.software_manager self._install_system_software() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - for software in self.software_manager.software.values(): - software.set_original_state() - - self.file_system.set_original_state() - - for nic in self.nics.values(): - nic.set_original_state() - - vals_to_include = { - "hostname", - "default_gateway", - "operating_state", - "revealed_to_red", - "start_up_duration", - "start_up_countdown", - "shut_down_duration", - "shut_down_countdown", - "is_resetting", - "node_scan_duration", - "node_scan_countdown", - "red_scan_countdown", - } - self._original_state = self.model_dump(include=vals_to_include) def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 9a34be0b..4b379be0 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -53,11 +53,6 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -101,28 +96,6 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.implicit_rule.set_original_state() - vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - - for i, rule in enumerate(self._acl): - if not rule: - continue - self._default_config[i] = {"action": rule.action.name} - if rule.src_ip_address: - self._default_config[i]["src_ip"] = str(rule.src_ip_address) - if rule.dst_ip_address: - self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) - if rule.src_port: - self._default_config[i]["src_port"] = rule.src_port.name - if rule.dst_port: - self._default_config[i]["dst_port"] = rule.dst_port.name - if rule.protocol: - self._default_config[i]["protocol"] = rule.protocol.name def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" @@ -389,11 +362,6 @@ class RouteEntry(SimComponent): metric: float = 0.0 "The cost metric for this route. Default is 0.0." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} - self._original_values = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -426,11 +394,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - self._original_state["routes_orig"] = self.routes - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.routes.clear() @@ -808,16 +771,6 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.acl.set_original_state() - self.route_table.set_original_state() - super().set_original_state() - vals_to_include = {"num_ports"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() @@ -987,7 +940,6 @@ class Router(Node): nic.ip_address = ip_address nic.subnet_mask = subnet_mask self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}") - self.set_original_state() def enable_port(self, port: int): """ diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 896861e6..18ed894c 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,10 +21,6 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - self.network.set_original_state() - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.network.reset_component_for_episode(episode) diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 322ac808..513606a9 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,12 +38,6 @@ class Application(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 2e0f4e3f..d05472d4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -30,14 +30,6 @@ class DatabaseClient(Application): kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"} - self._original_state.update(self.model_dump(include=vals_to_include)) def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index a844f059..bd4048c4 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -49,21 +49,6 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "server_ip_address", - "payload", - "server_password", - "port_scan_p_of_success", - "data_manipulation_p_of_success", - "attack_stage", - "repeat", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index dfc48dd3..d4ea1a20 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -57,22 +57,6 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def set_original_state(self): - """Set the original state of the Denial of Service Bot.""" - _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "target_ip_address", - "target_port", - "payload", - "repeat", - "attack_stage", - "max_sessions", - "port_scan_p_of_success", - "dos_intensity", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index eef0ed5d..f1dbe3ef 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -47,16 +47,8 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.set_original_state() self.run() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index b753e3ad..458a6b5c 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,12 +24,6 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index d75b4424..4159c87c 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -40,19 +40,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "password", - "connections", - "backup_server_ip", - "latest_backup_directory", - "latest_backup_file_name", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 2d3879ff..3c034705 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,13 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_server"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.dns_cache.clear() diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8decf7e9..eab94766 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,13 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_table"} - self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.dns_table.clear() diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 39bc57f0..457eaea9 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,13 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"connected"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index a82b0919..9534a5e9 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,13 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_password"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 162678a0..4102657c 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -78,12 +78,6 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index eaea6bb1..5888e72a 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,13 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"last_response_status_code"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 662db08e..1fb8c989 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -96,19 +96,6 @@ class Software(SimComponent): _patching_countdown: Optional[int] = None "Current number of ticks left to patch the software." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "name", - "health_state_actual", - "health_state_visible", - "criticality", - "patching_count", - "scanning_count", - "revealed_to_red", - } - self._original_state = self.model_dump(include=vals_to_include) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -245,12 +232,6 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 01ad3871..695b15dd 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -7,7 +7,6 @@ from primaite.simulator.domain.account import Account, AccountType @pytest.fixture(scope="function") def account() -> Account: acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER) - acct.set_original_state() return acct @@ -39,7 +38,6 @@ def test_original_state(account): account.log_on() account.log_off() account.disable() - account.set_original_state() account.log_on() state = account.describe_state() diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 9366d173..2fe3f04c 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -189,7 +189,6 @@ def test_reset_file_system(file_system): # file and folder that existed originally file_system.create_file(file_name="test_file.zip") file_system.create_folder(folder_name="test_folder") - file_system.set_original_state() # create a new file file_system.create_file(file_name="new_file.txt") diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 7667a59f..994e5a45 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -33,7 +33,6 @@ def network(example_network) -> Network: assert len(example_network.computers) is 2 assert len(example_network.servers) is 2 - example_network.set_original_state() example_network.show() return example_network diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 71489171..da29a439 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -22,7 +22,6 @@ def dos_bot() -> DoSBot: dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) - dos_bot.set_original_state() return dos_bot @@ -51,7 +50,6 @@ def test_dos_bot_reset(dos_bot): dos_bot.configure( target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True ) - dos_bot.set_original_state() dos_bot.reset_component_for_episode(episode=1) # should reset to the configured value assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") From 72f4cc0a5073e79f7b9734b27739b3264513c6f8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:56:25 +0000 Subject: [PATCH 03/14] Remove reset methods from most classes --- src/primaite/simulator/file_system/file.py | 5 --- .../simulator/file_system/file_system.py | 26 ----------- src/primaite/simulator/file_system/folder.py | 26 ----------- .../simulator/network/hardware/base.py | 9 ---- .../network/hardware/nodes/router.py | 45 ++++++------------- .../system/applications/database_client.py | 6 --- .../red_applications/data_manipulation_bot.py | 5 --- .../applications/red_applications/dos_bot.py | 5 --- .../system/applications/web_browser.py | 8 ---- .../services/database/database_service.py | 6 --- .../system/services/dns/dns_client.py | 5 --- .../system/services/dns/dns_server.py | 7 --- .../system/services/ftp/ftp_client.py | 5 --- .../system/services/ftp/ftp_server.py | 6 --- .../system/services/ntp/ntp_client.py | 13 +----- .../system/services/ntp/ntp_server.py | 10 ----- .../system/services/web_server/web_server.py | 5 --- 17 files changed, 16 insertions(+), 176 deletions(-) diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 4cd5cdbb..d9b02e8e 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -73,11 +73,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") - super().reset_component_for_episode(episode) - @property def path(self) -> str: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index a7252a2d..8fd4e5d7 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -34,32 +34,6 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") - # Move any 'original' folder that have been deleted back to folders - original_folder_uuids = self._original_state["original_folder_uuids"] - for uuid in original_folder_uuids: - if uuid in self.deleted_folders: - folder = self.deleted_folders[uuid] - self.deleted_folders.pop(uuid) - self.folders[uuid] = folder - - # Clear any other deleted folders that aren't original (have been created by agent) - self.deleted_folders.clear() - - # Now clear all non-original folders created by agent - current_folder_uuids = list(self.folders.keys()) - for uuid in current_folder_uuids: - if uuid not in original_folder_uuids: - folder = self.folders[uuid] - self.folders.pop(uuid) - - # Now reset all remaining folders - for folder in self.folders.values(): - folder.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 39c3dad8..771dc7a0 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -49,32 +49,6 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") - # Move any 'original' file that have been deleted back to files - original_file_uuids = self._original_state["original_file_uuids"] - for uuid in original_file_uuids: - if uuid in self.deleted_files: - file = self.deleted_files[uuid] - self.deleted_files.pop(uuid) - self.files[uuid] = file - - # Clear any other deleted files that aren't original (have been created by agent) - self.deleted_files.clear() - - # Now clear all non-original files created by agent - current_file_uuids = list(self.files.keys()) - for uuid in current_file_uuids: - if uuid not in original_file_uuids: - file = self.files[uuid] - self.files.pop(uuid) - - # Now reset all remaining files - for file in self.files.values(): - file.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 68f3816d..67ac42c8 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1015,15 +1015,6 @@ class Node(SimComponent): """Reset the original state of the SimComponent.""" super().reset_component_for_episode(episode) - # Reset ARP Cache - self.arp.clear() - - # Reset ICMP - self.icmp.clear() - - # Reset Session Manager - self.session_manager.clear() - # Reset File System self.file_system.reset_component_for_episode(episode) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 4b379be0..aa154ad9 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -84,9 +84,7 @@ class AccessControlList(SimComponent): implicit_action: ACLAction implicit_rule: ACLRule max_acl_rules: int = 25 - _acl: List[Optional[ACLRule]] = [None] * 24 - _default_config: Dict[int, dict] = {} - """Config dict describing how the ACL list should look at episode start""" + _acl: List[Optional[ACLRule]] = [None] * 24 # TODO: this ignores the max_acl_rules and assumes it's default def __init__(self, **kwargs) -> None: if not kwargs.get("implicit_action"): @@ -97,26 +95,6 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.implicit_rule.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - self._reset_rules_to_default() - - def _reset_rules_to_default(self) -> None: - """Clear all ACL rules and set them to the default rules config.""" - self._acl = [None] * (self.max_acl_rules - 1) - for r_num, r_cfg in self._default_config.items(): - self.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("src_ip"), - dst_ip_address=r_cfg.get("dst_ip"), - position=r_num, - ) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -394,12 +372,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.routes.clear() - self.routes = self._original_state["routes_orig"] - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -1040,7 +1012,18 @@ class Router(Node): ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"], ) + + # Add the router's default ACL rules from the config. if "acl" in cfg: - new.acl._default_config = cfg["acl"] # save the config to allow resetting - new.acl._reset_rules_to_default() # read the config and apply rules + for r_num, r_cfg in cfg["acl"].items(): + new.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + return new diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index d05472d4..25730c38 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -31,12 +31,6 @@ class DatabaseClient(Application): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - self._query_success_tracker.clear() - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index bd4048c4..5fe951b7 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -49,11 +49,6 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index d4ea1a20..9dac6b25 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -57,11 +57,6 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index f1dbe3ef..6f2c479c 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -49,11 +49,6 @@ class WebBrowser(Application): super().__init__(**kwargs) self.run() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( @@ -72,9 +67,6 @@ class WebBrowser(Application): state["history"] = [hist_item.state() for hist_item in self.history] return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - def get_webpage(self, url: Optional[str] = None) -> bool: """ Retrieve the webpage. diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 4159c87c..5425ce75 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -40,12 +40,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 3c034705..967af6b2 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,11 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_cache.clear() - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index eab94766..4d0ebbb8 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,13 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_table.clear() - for key, value in self._original_state["dns_table_orig"].items(): - self.dns_table[key] = value - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 457eaea9..7c334ced 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,11 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 9534a5e9..c5330de2 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,12 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index e8c3d0cb..5e4ae53a 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -1,6 +1,6 @@ from datetime import datetime from ipaddress import IPv4Address -from typing import Dict, Optional +from typing import Dict, List, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket @@ -49,21 +49,12 @@ class NTPClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def send( self, payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: [Port] = Port.NTP, + dest_port: List[Port] = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 0a66384a..29a320f6 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -34,16 +34,6 @@ class NTPServer(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including - resetting any stateful properties or statistics, and clearing any message - queues. - """ - pass - def receive( self, payload: NTPPacket, diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5888e72a..5e4a6f6e 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,11 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. From f933341df521feaca5e494bf739833b86d75ab28 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 23 Feb 2024 10:06:48 +0000 Subject: [PATCH 04/14] eod commit --- src/primaite/game/game.py | 10 ---------- src/primaite/simulator/core.py | 8 ++++++-- src/primaite/simulator/network/container.py | 8 ++++---- src/primaite/simulator/network/hardware/base.py | 14 +++++++------- .../simulator/network/hardware/nodes/router.py | 10 +++++----- src/primaite/simulator/sim_container.py | 4 ++-- 6 files changed, 24 insertions(+), 30 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index bd7ed2cd..72ad01e7 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -160,16 +160,6 @@ class PrimaiteGame: return True return False - def reset(self) -> None: # TODO: deprecated - remove me - """Reset the game, this will reset the simulation.""" - self.episode_counter += 1 - self.step_counter = 0 - _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(episode=self.episode_counter) - for agent in self.agents: - agent.reward_function.total_reward = 0.0 - agent.reset_agent_for_episode() - def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index e21ce9eb..b9188bf0 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -160,8 +160,12 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" + def setup_for_episode(self, episode: int): + """ + Perform any additional setup on this component that can't happen during __init__. + + For instance, some components may require for the entire network to exist before some configuration can be set. + """ pass def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 48205bbd..080a1004 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,12 +45,12 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.reset_component_for_episode(episode) + node.setup_for_episode(episode) for link in self.links.values(): - link.reset_component_for_episode(episode) + link.setup_for_episode(episode) for node in self.nodes.values(): node.power_on() @@ -171,7 +171,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.reset_component_for_episode() + link.setup_for_episode() def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 67ac42c8..e2a90db1 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,9 +123,9 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() @@ -1011,19 +1011,19 @@ class Node(SimComponent): self.session_manager.software_manager = self.software_manager self._install_system_software() - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) # Reset File System - self.file_system.reset_component_for_episode(episode) + self.file_system.setup_for_episode(episode) # Reset all Nics for nic in self.nics.values(): - nic.reset_component_for_episode(episode) + nic.setup_for_episode(episode) for software in self.software_manager.software.values(): - software.reset_component_for_episode(episode) + software.setup_for_episode(episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index aa154ad9..887bc9be 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -743,16 +743,16 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() - self.acl.reset_component_for_episode(episode) - self.route_table.reset_component_for_episode(episode) + self.acl.setup_for_episode(episode) + self.route_table.setup_for_episode(episode) for i, nic in self.ethernet_ports.items(): - nic.reset_component_for_episode(episode) + nic.setup_for_episode(episode) self.enable_port(i) - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 18ed894c..bb6132a8 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,9 +21,9 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.reset_component_for_episode(episode) + self.network.setup_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() From c115095157f27d6d7480df430c2c83e50078184d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 25 Feb 2024 16:17:12 +0000 Subject: [PATCH 05/14] Fix router from config using wrong method --- src/primaite/simulator/network/hardware/nodes/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 887bc9be..6bf80b3c 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -1016,7 +1016,7 @@ class Router(Node): # Add the router's default ACL rules from the config. if "acl" in cfg: for r_num, r_cfg in cfg["acl"].items(): - new.add_rule( + new.acl.add_rule( action=ACLAction[r_cfg["action"]], src_port=None if not (p := r_cfg.get("src_port")) else Port[p], dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], From 994dbc3501b7c50584322cf3bba9db7aa7e0b77d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 25 Feb 2024 17:44:41 +0000 Subject: [PATCH 06/14] Finalise the refactor. It works well now. --- .../config/_package_data/example_config.yaml | 5 +- src/primaite/game/game.py | 12 +++- src/primaite/notebooks/uc2_demo.ipynb | 66 +++++++++---------- src/primaite/session/environment.py | 7 +- src/primaite/simulator/network/container.py | 6 +- .../simulator/network/hardware/base.py | 10 +-- .../network/hardware/nodes/network/router.py | 2 +- src/primaite/simulator/sim_container.py | 2 +- .../system/services/web_server/web_server.py | 2 +- src/primaite/simulator/system/software.py | 7 +- 10 files changed, 65 insertions(+), 54 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f85baf10..a32696c7 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -652,12 +652,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 02d36c8a..f5649589 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -185,6 +185,10 @@ class PrimaiteGame: """Close the game, this will close the simulation.""" return NotImplemented + def setup_for_episode(self, episode: int) -> None: + """Perform any final configuration of components to make them ready for the game to start.""" + self.simulation.setup_for_episode(episode=episode) + @classmethod def from_config(cls, cfg: Dict) -> "PrimaiteGame": """Create a PrimaiteGame object from a config dictionary. @@ -258,7 +262,9 @@ class PrimaiteGame: new_service = new_node.software_manager.software[service_type] game.ref_map_services[service_ref] = new_service.uuid else: - _LOGGER.warning(f"service type not found {service_type}") + msg = f"Configuration contains an invalid service type: {service_type}" + _LOGGER.error(msg) + raise ValueError(msg) # service-dependent options if service_type == "DNSClient": if "options" in service_cfg: @@ -297,7 +303,9 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: - _LOGGER.warning(f"application type not found {application_type}") + msg = f"Configuration contains an invalid application type: {application_type}" + _LOGGER.error(msg) + raise ValueError(msg) if application_type == "DataManipulationBot": if "options" in application_cfg: diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index c4fe4c9a..460e3d27 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -335,9 +335,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", @@ -347,9 +345,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# Imports\n", @@ -372,9 +368,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# create the env\n", @@ -385,10 +379,10 @@ " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", - "game = PrimaiteGame.from_config(cfg)\n", - "env = PrimaiteGymEnv(game = game)\n", - "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", - "env.agent.flatten_obs = False\n", + " # don't flatten observations so that we can see what is going on\n", + " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", "obs, info = env.reset()\n", "print('env created successfully')\n", "pprint(obs)" @@ -422,9 +416,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "for step in range(35):\n", @@ -442,9 +434,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "pprint(obs['NODES'])" @@ -460,9 +450,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", @@ -488,9 +476,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", @@ -515,9 +501,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", @@ -540,9 +524,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "env.step(13) # Patch the database\n", @@ -582,6 +564,22 @@ "obs['ACL']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reset the cell, you can rerun the other cells to verify that the attack works the same every episode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.reset()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -592,7 +590,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -606,9 +604,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 2 } diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index ad770f8f..bab81253 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -74,6 +74,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" ) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) self.agent = self.game.rl_agents[0] self.episode_counter += 1 state = self.game.get_sim_state() @@ -97,12 +98,12 @@ class PrimaiteGymEnv(gymnasium.Env): def _get_obs(self) -> ObsType: """Return the current observation.""" - if not self.agent.flatten_obs: - return self.agent.observation_manager.current_observation - else: + if self.agent.flatten_obs: unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + else: + return self.agent.observation_manager.current_observation class PrimaiteRayEnv(gymnasium.Env): diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index c3ad84c3..b5a16430 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -48,9 +48,9 @@ class Network(SimComponent): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.setup_for_episode(episode) + node.setup_for_episode(episode=episode) for link in self.links.values(): - link.setup_for_episode(episode) + link.setup_for_episode(episode=episode) for node in self.nodes.values(): node.power_on() @@ -172,7 +172,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.setup_for_episode() + link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here. def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 771c3397..1b6d611e 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -90,7 +90,7 @@ class NetworkInterface(SimComponent, ABC): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().setup_for_episode(episode) + super().setup_for_episode(episode=episode) if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() @@ -643,17 +643,17 @@ class Node(SimComponent): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().setup_for_episode(episode) + super().setup_for_episode(episode=episode) # Reset File System - self.file_system.setup_for_episode(episode) + self.file_system.setup_for_episode(episode=episode) # Reset all Nics for network_interface in self.network_interfaces.values(): - network_interface.setup_for_episode(episode) + network_interface.setup_for_episode(episode=episode) for software in self.software_manager.software.values(): - software.setup_for_episode(episode) + software.setup_for_episode(episode=episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index c299dfb7..3111a153 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1078,7 +1078,7 @@ class Router(NetworkNode): for i, _ in self.network_interface.items(): self.enable_port(i) - super().setup_for_episode(episode) + super().setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index bb6132a8..a2285d92 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -23,7 +23,7 @@ class Simulation(SimComponent): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.setup_for_episode(episode) + self.network.setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 5e4a6f6e..ce29a2f9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -118,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: + except Exception: # TODO: refactor this. Likely to cause silent bugs. # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 56a1e3d1..8864659c 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -3,7 +3,7 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder @@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +if TYPE_CHECKING: + from primaite.simulator.system.core.software_manager import SoftwareManager + class SoftwareType(Enum): """ @@ -84,7 +87,7 @@ class Software(SimComponent): "The count of times the software has been scanned, defaults to 0." revealed_to_red: bool = False "Indicates if the software has been revealed to red agent, defaults is False." - software_manager: Any = None + software_manager: "SoftwareManager" = None "An instance of Software Manager that is used by the parent node." sys_log: SysLog = None "An instance of SysLog that is used by the parent node." From 63c9a36c30adf38716759526968067ae990f1fdc Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 25 Feb 2024 18:36:20 +0000 Subject: [PATCH 07/14] Fix typos --- src/primaite/notebooks/uc2_demo.ipynb | 2 +- src/primaite/simulator/system/services/ntp/ntp_client.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 460e3d27..7c90a885 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -568,7 +568,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Reset the cell, you can rerun the other cells to verify that the attack works the same every episode." + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode." ] }, { diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 1e9dc139..ad00065c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -1,6 +1,6 @@ from datetime import datetime from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Dict, Optional from primaite import getLogger from primaite.simulator.network.protocols.ntp import NTPPacket @@ -54,7 +54,7 @@ class NTPClient(Service): payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: List[Port] = Port.NTP, + dest_port: Port = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. From e5982c4599b07ef5cf994218f4323d1105f65bc7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 10:26:28 +0000 Subject: [PATCH 08/14] Change agents list in game object to dictionary --- .../example_config_2_rl_agents.yaml | 446 +++++++++++------- src/primaite/game/game.py | 18 +- .../training_example_ray_multi_agent.ipynb | 9 +- .../training_example_ray_single_agent.ipynb | 2 +- .../notebooks/training_example_sb3.ipynb | 11 +- src/primaite/session/environment.py | 52 +- tests/conftest.py | 2 +- tests/integration_tests/game_configuration.py | 16 +- 8 files changed, 331 insertions(+), 225 deletions(-) diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 93019c9d..1ccd7b38 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -10,6 +10,8 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true game: @@ -36,9 +38,9 @@ agents: - type: NODE_APPLICATION_EXECUTE options: nodes: - - node_ref: client_2 + - node_name: client_2 applications: - - application_ref: client_2_web_browser + - application_name: WebBrowser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -54,6 +56,31 @@ agents: frequency: 4 variance: 3 + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + reward_function: + reward_components: + - type: DUMMY + + + + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -72,7 +99,7 @@ agents: - type: NODE_OS_SCAN options: nodes: - - node_ref: client_1 + - node_name: client_1 applications: - application_name: DataManipulationBot - node_name: client_2 @@ -104,25 +131,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -137,23 +160,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -184,10 +207,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -242,25 +265,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -271,22 +294,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -303,63 +326,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -407,123 +430,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... - + flatten_obs: true - ref: defender_2 team: BLUE @@ -537,25 +585,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -570,23 +614,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -617,10 +661,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -675,25 +719,25 @@ agents: action: "NODE_FILE_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 11: action: "NODE_FILE_DELETE" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 12: action: "NODE_FILE_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 file_id: 0 13: action: "NODE_SERVICE_PATCH" @@ -704,22 +748,22 @@ agents: action: "NODE_FOLDER_SCAN" options: node_id: 2 - folder_id: 1 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: node_id: 2 - folder_id: 1 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: node_id: 2 - folder_id: 1 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: node_id: 2 - folder_id: 1 + folder_id: 0 18: action: "NODE_OS_SCAN" options: @@ -736,63 +780,63 @@ agents: action: "NODE_RESET" options: node_id: 5 - 22: + 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 7 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: + 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 24: # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 3 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 25: # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 26: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 27: action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 @@ -840,122 +884,148 @@ agents: action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 39: action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 + nic_id: 0 40: action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 41: action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 + nic_id: 0 42: action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 43: action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 + nic_id: 0 44: action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 45: action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 + nic_id: 0 46: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 47: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 1 + nic_id: 0 48: action: "NETWORK_NIC_DISABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 49: action: "NETWORK_NIC_ENABLE" options: node_id: 4 - nic_id: 2 + nic_id: 1 50: action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 51: action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 + nic_id: 0 52: action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 53: action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.34 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 options: - node_ref: web_server - service_ref: web_server_web_service + node_hostname: client_1 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.33 + options: + node_hostname: client_2 agent_settings: - # ... + flatten_obs: true @@ -1032,12 +1102,13 @@ simulation: default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -1089,10 +1160,14 @@ simulation: - ref: data_manipulation_bot type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ services: - ref: client_1_dns_client type: DNSClient @@ -1109,6 +1184,13 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f5649589..8edf70ea 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -79,11 +79,11 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - self.agents: List[AbstractAgent] = [] - """List of agents.""" + self.agents: Dict[str, AbstractAgent] = {} + """Mapping from agent name to agent object.""" - self.rl_agents: List[ProxyAgent] = [] - """Subset of agent list including only the reinforcement learning agents.""" + self.rl_agents: Dict[str, ProxyAgent] = {} + """Subset of agents which are intended for reinforcement learning.""" self.step_counter: int = 0 """Current timestep within the episode.""" @@ -144,7 +144,7 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for agent in self.agents: + for name, agent in self.agents.items(): agent.update_observation(state) agent.update_reward(state) agent.reward_function.total_reward += agent.reward_function.current_reward @@ -158,7 +158,7 @@ class PrimaiteGame: """ agent_actions = {} - for agent in self.agents: + for name, agent in self.agents.items(): obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) @@ -396,7 +396,6 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -405,8 +404,7 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) - game.rl_agents.append(new_agent) + game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], @@ -415,8 +413,8 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) else: _LOGGER.warning(f"agent type {agent_type} not found") + game.agents[agent_cfg["ref"]] = new_agent return game diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index 0d4b6d0e..4ef02443 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -60,7 +60,7 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", " .rollouts(num_rollout_workers=0)\n", " .training(train_batch_size=128)\n", " )\n" @@ -88,6 +88,13 @@ " param_space=config\n", ").fit()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index ea006ae9..3c27bdc6 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -54,7 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "env_config = {\"cfg\":cfg}\n", + "env_config = cfg\n", "\n", "config = (\n", " PPOConfig()\n", diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index 164142b2..0472854e 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -27,9 +27,7 @@ "outputs": [], "source": [ "with open(example_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - "game = PrimaiteGame.from_config(cfg)" + " cfg = yaml.safe_load(f)\n" ] }, { @@ -76,6 +74,13 @@ "source": [ "model.save(\"deleteme\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index bab81253..f8dbab9d 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType @@ -25,12 +25,17 @@ class PrimaiteGymEnv(gymnasium.Env): """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) """Current game.""" - self.agent: ProxyAgent = self.game.rl_agents[0] - """The agent within the game that is controlled by the RL algorithm.""" + self._agent_name = next(iter(self.game.rl_agents)) + """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" self.episode_counter: int = 0 """Current episode number.""" + @property + def agent(self) -> ProxyAgent: + """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" + return self.game.rl_agents[self._agent_name] + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" # make ProxyAgent store the action chosen my the RL policy @@ -71,11 +76,10 @@ class PrimaiteGymEnv(gymnasium.Env): """Reset the environment.""" print( f"Resetting environment, episode {self.episode_counter}, " - f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" + f"avg. reward: {self.agent.reward_function.total_reward}" ) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) self.game.setup_for_episode(episode=self.episode_counter) - self.agent = self.game.rl_agents[0] self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) @@ -112,11 +116,10 @@ class PrimaiteRayEnv(gymnasium.Env): def __init__(self, env_config: Dict) -> None: """Initialise the environment. - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict """ - self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) + self.env = PrimaiteGymEnv(game_config=env_config) self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -138,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :type env_config: Dict """ - self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"]) + self.game_config: Dict = env_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) """Reference to the primaite game""" - self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} - """List of all possible agents in the environment. This list should not change!""" - self._agent_ids = list(self.agents.keys()) + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. This is a list of strings of agent names.""" + self.episode_counter: int = 0 + """Current episode number.""" self.terminateds = set() self.truncateds = set() @@ -159,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) super().__init__() + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -182,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -196,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {"agent_actions": agent_actions} + infos = {name: {} for name, _ in self.agents.items()} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: @@ -222,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} - for name, agent in self.agents.items(): + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) return obs diff --git a/tests/conftest.py b/tests/conftest.py index 5084c339..83ac9559 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -510,6 +510,6 @@ def game_and_agent(): reward_function=reward_function, ) - game.agents.append(test_agent) + game.agents["test_agent"] = test_agent return (game, test_agent) diff --git a/tests/integration_tests/game_configuration.py b/tests/integration_tests/game_configuration.py index 3bd870e3..f3dc51bd 100644 --- a/tests/integration_tests/game_configuration.py +++ b/tests/integration_tests/game_configuration.py @@ -42,20 +42,20 @@ def test_example_config(): assert len(game.agents) == 4 # red, blue and 2 green agents # green agent 1 - assert game.agents[0].agent_name == "client_2_green_user" - assert isinstance(game.agents[0], RandomAgent) + assert "client_2_green_user" in game.agents + assert isinstance(game.agents["client_2_green_user"], RandomAgent) # green agent 2 - assert game.agents[1].agent_name == "client_1_green_user" - assert isinstance(game.agents[1], RandomAgent) + assert "client_1_green_user" in game.agents + assert isinstance(game.agents["client_1_green_user"], RandomAgent) # red agent - assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot" - assert isinstance(game.agents[2], DataManipulationAgent) + assert "client_1_data_manipulation_red_bot" in game.agents + assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) # blue agent - assert game.agents[3].agent_name == "defender" - assert isinstance(game.agents[3], ProxyAgent) + assert "defender" in game.agents + assert isinstance(game.agents["defender"], ProxyAgent) network: Network = game.simulation.network From ccb10f1160cee454216c205e15c08d2ecc0c02d7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 11:02:37 +0000 Subject: [PATCH 09/14] Update docs based on reset refactor --- CHANGELOG.md | 1 + docs/index.rst | 1 + docs/source/environment.rst | 10 ++++++++++ docs/source/game_layer.rst | 5 +++++ docs/source/primaite_session.rst | 5 +++++ 5 files changed, 22 insertions(+) create mode 100644 docs/source/environment.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index 01e45d2e..d2a582be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Made environment reset completely recreate the game object. - Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack. - Changed the data manipulation scenario to include a second green agent on client 1. - Refactored actions and observations to be configurable via object name, instead of UUID. diff --git a/docs/index.rst b/docs/index.rst index 9eae8adc..08e0ac21 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -108,6 +108,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/simulation source/game_layer source/config + source/environment .. toctree:: :caption: Developer information: diff --git a/docs/source/environment.rst b/docs/source/environment.rst new file mode 100644 index 00000000..87e7f060 --- /dev/null +++ b/docs/source/environment.rst @@ -0,0 +1,10 @@ +RL Environments +*************** + +RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs: + +* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation `_. +* Ray Single agent API - For training a single Ray RLLib agent +* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. + +There is a Jupyter notebook which demonstrates integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index cdae17dd..1f2921fe 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -20,6 +20,11 @@ The game layer is responsible for managing agents and getting them to interface PrimAITE Session ^^^^^^^^^^^^^^^ +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + ``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types. Agents diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 706397b6..87a3f03d 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -4,6 +4,11 @@ .. _run a primaite session: +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + Run a PrimAITE Session ====================== From a5043a8fbe01b38699cedf677129a17bec32655d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 12:15:53 +0000 Subject: [PATCH 10/14] Modify tests based on refactoring --- src/primaite/session/session.py | 6 ++-- src/primaite/simulator/network/airspace.py | 5 --- .../network/hardware/nodes/host/host_node.py | 5 --- .../hardware/nodes/network/firewall.py | 18 ---------- .../network/hardware/nodes/network/switch.py | 6 ---- .../hardware/nodes/network/wireless_router.py | 3 -- .../assets/configs/bad_primaite_session.yaml | 7 ++-- .../configs/eval_only_primaite_session.yaml | 9 ++--- tests/assets/configs/multi_agent_session.yaml | 10 +++--- .../assets/configs/test_primaite_session.yaml | 9 ++--- .../configs/train_only_primaite_session.yaml | 9 ++--- .../environments/test_sb3_environment.py | 3 +- .../_simulator/_domain/test_account.py | 19 ----------- .../_file_system/test_file_system.py | 31 ----------------- .../_simulator/_network/test_container.py | 34 ------------------- .../_red_applications/test_dos_bot.py | 28 --------------- 16 files changed, 28 insertions(+), 174 deletions(-) diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 5c663cfd..b8f80e95 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -101,11 +101,11 @@ class PrimaiteSession: # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayEnv(env_config=cfg) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayMARLEnv(env_config=cfg) elif sess.training_options.rl_framework == "SB3": - sess.env = PrimaiteGymEnv(game=game) + sess.env = PrimaiteGymEnv(game_config=cfg) sess.policy = PolicyABC.from_config(sess.training_options, session=sess) if agent_load_path: diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 724b8728..d264f751 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def enable(self): """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 3f34f736..329a5fa0 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -213,11 +213,6 @@ class NIC(IPWiredNetworkInterface): return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def receive_frame(self, frame: Frame) -> bool: """ Attempt to receive and process a network frame from the connected Link. diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 22effa2a..f2305652 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -109,24 +109,6 @@ class Firewall(Router): sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound" ) - self.set_original_state() - - def set_original_state(self): - """Set the original state for the Firewall.""" - super().set_original_state() - vals_to_include = { - "internal_port", - "external_port", - "dmz_port", - "internal_inbound_acl", - "internal_outbound_acl", - "dmz_inbound_acl", - "dmz_outbound_acl", - "external_inbound_acl", - "external_outbound_acl", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def describe_state(self) -> Dict: """ Describes the current state of the Firewall. diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 33e6ee9a..557ea287 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface): _connected_node: Optional[Switch] = None "The Switch to which the SwitchPort is connected." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index dd0b58d3..91833d6a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -122,8 +122,6 @@ class WirelessRouter(Router): self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) - self.set_original_state() - @property def wireless_access_point(self) -> WirelessAccessPoint: """ @@ -166,7 +164,6 @@ class WirelessRouter(Router): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured WAP {network_interface}") - self.set_original_state() self.wireless_access_point.frequency = frequency # Set operating frequency self.wireless_access_point.enable() # Re-enable the WAP with new settings diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 5bdc3273..c76aeef6 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -589,15 +589,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 8361e318..1cb59f87 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -593,15 +593,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -624,7 +625,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 87bd9d1c..b1b15372 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -1043,16 +1043,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - - ref: database_server type: server @@ -1074,7 +1074,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 76190a64..e5f9d544 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -599,15 +599,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -630,7 +631,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5d004c7e..10e088d8 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -600,15 +600,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -631,7 +632,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 91cf5c1e..dc5d10e9 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -17,8 +17,7 @@ def test_sb3_compatibility(): with open(example_config_path(), "r") as f: cfg = yaml.safe_load(f) - game = PrimaiteGame.from_config(cfg) - gym = PrimaiteGymEnv(game=game) + gym = PrimaiteGymEnv(game_config=cfg) model = PPO("MlpPolicy", gym) model.learn(total_timesteps=1000) diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 695b15dd..786fe851 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -12,20 +12,6 @@ def account() -> Account: def test_original_state(account): """Test the original state - see if it resets properly""" - account.log_on() - account.log_off() - account.disable() - - state = account.describe_state() - assert state["num_logons"] is 1 - assert state["num_logoffs"] is 1 - assert state["num_group_changes"] is 0 - assert state["username"] is "Jake" - assert state["password"] is "totally_hashed_password" - assert state["account_type"] is AccountType.USER.value - assert state["enabled"] is False - - account.reset_component_for_episode(episode=1) state = account.describe_state() assert state["num_logons"] is 0 assert state["num_logoffs"] is 0 @@ -39,11 +25,6 @@ def test_original_state(account): account.log_off() account.disable() - account.log_on() - state = account.describe_state() - assert state["num_logons"] is 2 - - account.reset_component_for_episode(episode=2) state = account.describe_state() assert state["num_logons"] is 1 assert state["num_logoffs"] is 1 diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 2fe3f04c..4defc80c 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -185,37 +185,6 @@ def test_get_file(file_system): file_system.show(full=True) -def test_reset_file_system(file_system): - # file and folder that existed originally - file_system.create_file(file_name="test_file.zip") - file_system.create_folder(folder_name="test_folder") - - # create a new file - file_system.create_file(file_name="new_file.txt") - - # create a new folder - file_system.create_folder(folder_name="new_folder") - - # delete the file that existed originally - file_system.delete_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None - - # delete the folder that existed originally - file_system.delete_folder(folder_name="test_folder") - assert file_system.get_folder(folder_name="test_folder") is None - - # reset - file_system.reset_component_for_episode(episode=1) - - # deleted original file and folder should be back - assert file_system.get_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_folder(folder_name="test_folder") - - # new file and folder should be removed - assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None - assert file_system.get_folder(folder_name="new_folder") is None - - @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index bf79677e..2cfc3f11 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -44,40 +44,6 @@ def test_describe_state(network): assert len(state["links"]) is 6 -def test_reset_network(network): - """ - Test that the network is properly reset. - - TODO: make sure that once implemented - any installed/uninstalled services, processes, apps, - etc are also removed/reinstalled - - """ - state_before = network.describe_state() - - client_1: Computer = network.get_node_by_hostname("client_1") - server_1: Computer = network.get_node_by_hostname("server_1") - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - - client_1.power_off() - assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - server_1.power_off() - assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - assert network.describe_state() != state_before - - network.reset_component_for_episode(episode=1) - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - # don't worry if UUIDs change - a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) - b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) - assert a == b - - def test_creating_container(): """Check that we can create a network container""" net = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index 1f28244d..4bfd28d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -27,34 +27,6 @@ def test_dos_bot_creation(dos_bot): assert dos_bot is not None -def test_dos_bot_reset(dos_bot): - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - - # should reset the relevant items - dos_bot.reset_component_for_episode(episode=0) - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - dos_bot.reset_component_for_episode(episode=1) - # should reset to the configured value - assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") - assert dos_bot.target_port is Port.HTTP - assert dos_bot.payload == "payload" - assert dos_bot.repeat is True - - def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node: Computer = dos_bot.parent assert dos_bot_node.operating_state is NodeOperatingState.ON From 2076b011ba8837f8e85e362a68c71b1010aa8e0a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 14:26:47 +0000 Subject: [PATCH 11/14] Put back default router rules --- src/primaite/simulator/network/hardware/nodes/network/router.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 3111a153..52f38eb6 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1039,6 +1039,8 @@ class Router(NetworkNode): self.connect_nic(network_interface) self.network_interface[i] = network_interface + self._set_default_acl() + def _install_system_software(self): """ Installs essential system software and network services on the router. From f9cc5af7aab3d822dcc23472f65c74a04ced4650 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 16:06:58 +0000 Subject: [PATCH 12/14] Not sure how this test was passing before --- .../_primaite/_simulator/_system/test_software.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index e77cd895..6f680012 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,12 +2,14 @@ from typing import Dict import pytest +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.software import Software, SoftwareHealthState +from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import IOSoftware, SoftwareHealthState -class TestSoftware(Software): +class TestSoftware(Service): def describe_state(self) -> Dict: pass @@ -15,7 +17,11 @@ class TestSoftware(Software): @pytest.fixture(scope="function") def software(file_system): return TestSoftware( - name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + name="TestSoftware", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="test_service"), + protocol=IPProtocol.TCP, ) From 33d2ecc26a4ac125607477c0a2846afa1b6fc728 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 16:58:43 +0000 Subject: [PATCH 13/14] Apply suggestions from code review. --- docs/source/environment.rst | 2 +- .../config/_package_data/example_config_2_rl_agents.yaml | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/source/environment.rst b/docs/source/environment.rst index 87e7f060..2b76572d 100644 --- a/docs/source/environment.rst +++ b/docs/source/environment.rst @@ -7,4 +7,4 @@ RL environments are the objects that directly interface with RL libraries such a * Ray Single agent API - For training a single Ray RLLib agent * Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. -There is a Jupyter notebook which demonstrates integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. +There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 1ccd7b38..c1e077be 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -1,11 +1,17 @@ training_config: rl_framework: RLLIB_multi_agent - # rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 256 + deterministic_eval: false n_agents: 2 agent_references: - defender_1 - defender_2 + io_settings: save_checkpoints: true checkpoint_interval: 5 From 8730330f73bf5b38ade30bfc18c23ee3c5523367 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Feb 2024 10:14:31 +0000 Subject: [PATCH 14/14] Apply PR suggestions --- src/primaite/config/_package_data/example_config.yaml | 2 +- .../_package_data/example_config_2_rl_agents.yaml | 2 +- src/primaite/game/game.py | 10 ++++++---- .../simulator/network/hardware/nodes/network/router.py | 1 - .../simulator/system/services/web_server/web_server.py | 2 +- .../environments/test_sb3_environment.py | 1 - 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index a32696c7..ebee4980 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -14,7 +14,7 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false - save_pcap_logs: true + save_pcap_logs: false save_sys_logs: true diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index c1e077be..992c3a1a 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -16,7 +16,7 @@ io_settings: save_checkpoints: true checkpoint_interval: 5 save_step_metadata: false - save_pcap_logs: true + save_pcap_logs: false save_sys_logs: true diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8edf70ea..baa84d1d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -133,7 +133,7 @@ class PrimaiteGame: self.update_agents(sim_state) # Apply all actions to simulation as requests - agent_actions = self.apply_agent_actions() # noqa + self.apply_agent_actions() # Advance timestep self.advance_timestep() @@ -144,7 +144,7 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for name, agent in self.agents.items(): + for _, agent in self.agents.items(): agent.update_observation(state) agent.update_reward(state) agent.reward_function.total_reward += agent.reward_function.current_reward @@ -158,7 +158,7 @@ class PrimaiteGame: """ agent_actions = {} - for name, agent in self.agents.items(): + for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward action_choice, options = agent.get_action(obs, rew) @@ -414,7 +414,9 @@ class PrimaiteGame: agent_settings=agent_settings, ) else: - _LOGGER.warning(f"agent type {agent_type} not found") + msg(f"Configuration error: {agent_type} is not a valid agent type.") + _LOGGER.error(msg) + raise ValueError(msg) game.agents[agent_cfg["ref"]] = new_agent return game diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 52f38eb6..aa6eec3a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1076,7 +1076,6 @@ class Router(NetworkNode): :param episode: The episode number for which the router is being reset. """ self.software_manager.arp.clear() - # self.acl.reset_component_for_episode(episode) for i, _ in self.network_interface.items(): self.enable_port(i) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index ce29a2f9..5e7591e9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -118,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: # TODO: refactor this. Likely to cause silent bugs. + except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index dc5d10e9..c48ddbc9 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -11,7 +11,6 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv -# @pytest.mark.skip(reason="no way of currently testing this") def test_sb3_compatibility(): """Test that the Gymnasium environment can be used with an SB3 agent.""" with open(example_config_path(), "r") as f: