From e9eef2b4c09d12d7f42624a9667b7be1597f6b80 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Mar 2024 11:16:27 +0000 Subject: [PATCH 1/8] #2350: add num_access, num_file_deletions and num_creations to file system --- src/primaite/simulator/file_system/file.py | 19 +++++++ .../simulator/file_system/file_system.py | 25 +++++++-- src/primaite/simulator/file_system/folder.py | 1 + .../system/applications/application.py | 10 ++++ .../system/applications/database_client.py | 2 + .../red_applications/data_manipulation_bot.py | 2 + .../system/applications/web_browser.py | 2 + .../_file_system/test_file_system.py | 52 ++++++++++++++++++- 8 files changed, 108 insertions(+), 5 deletions(-) diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index d9b02e8e..0897178d 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -38,6 +38,8 @@ class File(FileSystemItemABC): "The Path if real is True." sim_root: Optional[Path] = None "Root path of the simulation." + num_access: int = 0 + "Number of times the file was accessed in the current step." def __init__(self, **kwargs): """ @@ -93,11 +95,23 @@ class File(FileSystemItemABC): return os.path.getsize(self.sim_path) return self.sim_size + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the file. + + :param timestep: The current timestep of the simulation. + """ + super().apply_timestep(timestep=timestep) + + # reset the number of accesses to 0 + self.num_access = 0 + def describe_state(self) -> Dict: """Produce a dictionary describing the current state of this object.""" state = super().describe_state() state["size"] = self.size state["file_type"] = self.file_type.name + state["num_access"] = self.num_access return state def scan(self) -> None: @@ -106,6 +120,7 @@ class File(FileSystemItemABC): self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}") return + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}") self.visible_health_status = self.health_status @@ -160,6 +175,7 @@ class File(FileSystemItemABC): if self.health_status == FileSystemItemHealthStatus.CORRUPT: self.health_status = FileSystemItemHealthStatus.GOOD + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}") @@ -173,6 +189,7 @@ class File(FileSystemItemABC): if self.health_status == FileSystemItemHealthStatus.GOOD: self.health_status = FileSystemItemHealthStatus.CORRUPT + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}") @@ -185,6 +202,7 @@ class File(FileSystemItemABC): if self.health_status == FileSystemItemHealthStatus.CORRUPT: self.health_status = FileSystemItemHealthStatus.GOOD + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}") @@ -194,5 +212,6 @@ class File(FileSystemItemABC): self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}") return + self.num_access += 1 # file was accessed self.deleted = True self.sys_log.info(f"File deleted {self.folder_name}/{self.name}") diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 8fd4e5d7..52144c72 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -27,6 +27,10 @@ class FileSystem(SimComponent): "Instance of SysLog used to create system logs." sim_root: Path "Root path of the simulation." + num_file_creations: int = 0 + "Number of file creations in the current step." + num_file_deletions: int = 0 + "Number of file deletions in the current step." def __init__(self, **kwargs): super().__init__(**kwargs) @@ -248,6 +252,8 @@ class FileSystem(SimComponent): ) folder.add_file(file) self._file_request_manager.add_request(name=file.name, request_type=RequestType(func=file._request_manager)) + # increment file creation + self.num_file_creations += 1 return file def get_file(self, folder_name: str, file_name: str, include_deleted: Optional[bool] = False) -> Optional[File]: @@ -308,6 +314,8 @@ class FileSystem(SimComponent): if folder: file = folder.get_file(file_name) if file: + # increment file creation + self.num_file_deletions += 1 folder.remove_file(file) def delete_file_by_id(self, folder_uuid: str, file_uuid: str): @@ -337,15 +345,14 @@ class FileSystem(SimComponent): """ file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) if file: - src_folder = file.folder - # remove file from src - src_folder.remove_file(file) + self.delete_file(folder_name=file.folder_name, file_name=file.name) dst_folder = self.get_folder(folder_name=dst_folder_name) if not dst_folder: dst_folder = self.create_folder(dst_folder_name) # add file to dst dst_folder.add_file(file) + self.num_file_creations += 1 if file.real: old_sim_path = file.sim_path file.sim_path = file.sim_root / file.path @@ -373,6 +380,10 @@ class FileSystem(SimComponent): folder_name=dst_folder.name, **file.model_dump(exclude={"uuid", "folder_id", "folder_name", "sim_path"}), ) + self.num_file_creations += 1 + # increment access counter + file.num_access += 1 + dst_folder.add_file(file_copy, force=True) if file.real: @@ -390,12 +401,20 @@ class FileSystem(SimComponent): state = super().describe_state() state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()} state["deleted_folders"] = {folder.name: folder.describe_state() for folder in self.deleted_folders.values()} + state["num_file_creations"] = self.num_file_creations + state["num_file_deletions"] = self.num_file_deletions return state def apply_timestep(self, timestep: int) -> None: """Apply time step to FileSystem and its child folders and files.""" super().apply_timestep(timestep=timestep) + # reset number of file creations + self.num_file_creations = 0 + + # reset number of file deletions + self.num_file_deletions = 0 + # apply timestep to folders for folder_id in self.folders: self.folders[folder_id].apply_timestep(timestep=timestep) diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 771dc7a0..3ddc1e5f 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -131,6 +131,7 @@ class Folder(FileSystemItemABC): file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: self.visible_health_status = FileSystemItemHealthStatus.CORRUPT + self.visible_health_status = self.health_status def _reveal_to_red_timestep(self) -> None: """Apply reveal to red timestep.""" diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 513606a9..74013681 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -59,6 +59,16 @@ class Application(IOSoftware): ) return state + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the application. + + :param timestep: The current timestep of the simulation. + """ + super().apply_timestep(timestep=timestep) + + self.num_executions = 0 # reset number of executions + def _can_perform_action(self) -> bool: """ Checks if the application can perform actions. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 7b259ff4..302aca7e 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -76,6 +76,8 @@ class DatabaseClient(Application): if not self._can_perform_action(): return False + self.num_executions += 1 # trying to connect counts as an execution + if not connection_id: connection_id = str(uuid4()) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index ee98ea8e..cce9fe8d 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -194,6 +194,8 @@ class DataManipulationBot(Application): """ if not self._can_perform_action(): return + + self.num_executions += 1 if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") self._logon() diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 9fa86328..90eda426 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -80,6 +80,8 @@ class WebBrowser(Application): if not self._can_perform_action(): return False + self.num_executions += 1 # trying to connect counts as an execution + # reset latest response self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 4defc80c..05824834 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,7 +1,9 @@ import pytest +from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.file_system.file_type import FileType +from primaite.simulator.file_system.folder import Folder def test_create_folder_and_file(file_system): @@ -14,8 +16,15 @@ def test_create_folder_and_file(file_system): assert len(file_system.get_folder("test_folder").files) == 1 + assert file_system.num_file_creations == 1 + assert file_system.get_folder("test_folder").get_file("test_file.txt") + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) @@ -23,24 +32,37 @@ def test_create_file_no_folder(file_system): """Tests that creating a file without a folder creates a folder and sets that as the file's parent.""" file = file_system.create_file(file_name="test_file.txt", size=10) assert len(file_system.folders) is 1 + assert file_system.num_file_creations == 1 assert file_system.get_folder("root").get_file("test_file.txt") == file assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT assert file_system.get_folder("root").get_file("test_file.txt").size == 10 + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) def test_delete_file(file_system): """Tests that a file can be deleted.""" - file_system.create_file(file_name="test_file.txt") + file = file_system.create_file(file_name="test_file.txt") assert len(file_system.folders) == 1 assert len(file_system.get_folder("root").files) == 1 file_system.delete_file(folder_name="root", file_name="test_file.txt") + assert file.num_access == 1 + assert file_system.num_file_deletions == 1 assert len(file_system.folders) == 1 assert len(file_system.get_folder("root").files) == 0 assert len(file_system.get_folder("root").deleted_files) == 1 + file_system.apply_timestep(0) + + # num file deletions should reset + assert file_system.num_file_deletions == 0 + file_system.show(full=True) @@ -54,6 +76,7 @@ def test_delete_non_existent_file(file_system): # deleting should not change how many files are in folder file_system.delete_file(folder_name="root", file_name="does_not_exist!") + assert file_system.num_file_deletions == 0 # should still only be one folder assert len(file_system.folders) == 1 @@ -96,6 +119,7 @@ def test_create_duplicate_file(file_system): assert len(file_system.folders) is 2 file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + assert file_system.num_file_creations == 1 assert len(file_system.get_folder("test_folder").files) == 1 @@ -103,6 +127,7 @@ def test_create_duplicate_file(file_system): file_system.create_file(file_name="test_file.txt", folder_name="test_folder") assert len(file_system.get_folder("test_folder").files) == 1 + assert file_system.num_file_creations == 1 file_system.show(full=True) @@ -136,13 +161,24 @@ def test_move_file(file_system): assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 0 + assert file_system.num_file_deletions == 0 + assert file_system.num_file_creations == 1 file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") + assert file_system.num_file_deletions == 1 + assert file_system.num_file_creations == 2 + assert file.num_access == 1 assert len(file_system.get_folder("src_folder").files) == 0 assert len(file_system.get_folder("dst_folder").files) == 1 assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid + file_system.apply_timestep(0) + + # num file creations and deletions should reset + assert file_system.num_file_creations == 0 + assert file_system.num_file_deletions == 0 + file_system.show(full=True) @@ -152,17 +188,25 @@ def test_copy_file(file_system): file_system.create_folder(folder_name="dst_folder") file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True) + assert file_system.num_file_creations == 1 original_uuid = file.uuid assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 0 file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") + assert file_system.num_file_creations == 2 + assert file.num_access == 1 assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 1 assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) @@ -172,13 +216,17 @@ def test_get_file(file_system): file1: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") file2: File = file_system.create_file(file_name="test_file2.txt", folder_name="test_folder") - folder.remove_file(file2) + file_system.delete_file("test_folder", "test_file2.txt") + # file 2 was accessed before being deleted + assert file2.num_access == 1 assert file_system.get_file_by_id(file_uuid=file1.uuid, folder_uuid=folder.uuid) is not None assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid) is None assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid, include_deleted=True) is not None assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None + assert file2.num_access == 1 # cannot access deleted file + file_system.delete_folder(folder_name="test_folder") assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None From b13725721d2a636e77334954a1f59de16c11fcbb Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Mar 2024 13:49:00 +0000 Subject: [PATCH 2/8] #2350: splitting observations into separate files --- src/primaite/game/agent/interface.py | 21 +- .../game/agent/observations/__init__.py | 0 .../agent/observations/agent_observations.py | 188 ++++++++++++++ .../agent/observations/observation_manager.py | 73 ++++++ .../agent/{ => observations}/observations.py | 234 ------------------ .../game/agent/scripted_agents/__init__.py | 0 .../data_manipulation_bot.py | 0 .../probabilistic_agent.py} | 2 +- .../agent/scripted_agents/random_agent.py | 21 ++ src/primaite/game/game.py | 6 +- tests/conftest.py | 3 +- ...software_installation_and_configuration.py | 2 +- .../game_layer/test_actions.py | 18 +- .../game_layer/test_observations.py | 2 +- .../network/test_capture_nmne.py | 2 +- .../_game/_agent/test_probabilistic_agent.py | 5 +- 16 files changed, 300 insertions(+), 277 deletions(-) create mode 100644 src/primaite/game/agent/observations/__init__.py create mode 100644 src/primaite/game/agent/observations/agent_observations.py create mode 100644 src/primaite/game/agent/observations/observation_manager.py rename src/primaite/game/agent/{ => observations}/observations.py (79%) create mode 100644 src/primaite/game/agent/scripted_agents/__init__.py rename src/primaite/game/agent/{ => scripted_agents}/data_manipulation_bot.py (100%) rename src/primaite/game/agent/{scripted_agents.py => scripted_agents/probabilistic_agent.py} (97%) create mode 100644 src/primaite/game/agent/scripted_agents/random_agent.py diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 88848479..e641fabb 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction if TYPE_CHECKING: @@ -146,23 +146,10 @@ class AbstractAgent(ABC): class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" - pass - - -class RandomAgent(AbstractScriptedAgent): - """Agent that ignores its observation and acts completely at random.""" - + @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: - """Sample the action space randomly. - - :param obs: Current observation for this agent, not used in RandomAgent - :type obs: ObsType - :param timestep: The current simulation timestep, not used in RandomAgent - :type timestep: int - :return: Action formatted in CAOS format - :rtype: Tuple[str, Dict] - """ - return self.action_manager.get_action(self.action_manager.space.sample()) + """Return an action to be taken in the environment.""" + return super().get_action(obs=obs, timestep=timestep) class ProxyAgent(AbstractAgent): diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py new file mode 100644 index 00000000..522cdb59 --- /dev/null +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -0,0 +1,188 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import ( + AbstractObservation, + AclObservation, + ICSObservation, + LinkObservation, + NodeObservation, + NullObservation, +) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class UC2BlueObservation(AbstractObservation): + """Container for all observations used by the blue agent in UC2. + + TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler + for the purpose of compiling several observation components. + """ + + def __init__( + self, + nodes: List[NodeObservation], + links: List[LinkObservation], + acl: AclObservation, + ics: ICSObservation, + where: Optional[List[str]] = None, + ) -> None: + """Initialise UC2 blue observation. + + :param nodes: List of node observations + :type nodes: List[NodeObservation] + :param links: List of link observations + :type links: List[LinkObservation] + :param acl: The Access Control List observation + :type acl: AclObservation + :param ics: The ICS observation + :type ics: ICSObservation + :param where: Where in the simulation state dict to find information. Not used in this particular observation + because it only compiles other observations and doesn't contribute any new information, defaults to None + :type where: Optional[List[str]], optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + self.nodes: List[NodeObservation] = nodes + self.links: List[LinkObservation] = links + self.acl: AclObservation = acl + self.ics: ICSObservation = ics + + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, + "ACL": self.acl.default_observation, + "ICS": self.ics.default_observation, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + obs = {} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} + obs["ACL"] = self.acl.observe(state) + obs["ICS"] = self.ics.observe(state) + + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), + "ACL": self.acl.space, + "ICS": self.ics.space, + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": + """Create UC2 blue observation from a config. + + :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, + links, ACL and ICS observations. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Constructed UC2 blue observation + :rtype: UC2BlueObservation + """ + node_configs = config["nodes"] + + num_services_per_node = config["num_services_per_node"] + num_folders_per_node = config["num_folders_per_node"] + num_files_per_folder = config["num_files_per_folder"] + num_nics_per_node = config["num_nics_per_node"] + nodes = [ + NodeObservation.from_config( + config=n, + game=game, + num_services_per_node=num_services_per_node, + num_folders_per_node=num_folders_per_node, + num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, + ) + for n in node_configs + ] + + link_configs = config["links"] + links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] + + acl_config = config["acl"] + acl = AclObservation.from_config(config=acl_config, game=game) + + ics_config = config["ics"] + ics = ICSObservation.from_config(config=ics_config, game=game) + new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) + return new + + +class UC2RedObservation(AbstractObservation): + """Container for all observations used by the red agent in UC2.""" + + def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: + super().__init__() + self.where: Optional[List[str]] = where + self.nodes: List[NodeObservation] = nodes + + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation.""" + if self.where is None: + return self.default_observation + + obs = {} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": + """ + Create UC2 red observation from a config. + + :param config: Dictionary containing the configuration for this UC2 red observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + """ + node_configs = config["nodes"] + nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] + return cls(nodes=nodes, where=["network"]) + + +class UC2GreenObservation(NullObservation): + """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" + + pass diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py new file mode 100644 index 00000000..400345fa --- /dev/null +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -0,0 +1,73 @@ +from typing import Dict, TYPE_CHECKING + +from gymnasium.core import ObsType + +from primaite.game.agent.observations.agent_observations import ( + UC2BlueObservation, + UC2GreenObservation, + UC2RedObservation, +) +from primaite.game.agent.observations.observations import AbstractObservation + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class ObservationManager: + """ + Manage the observations of an Agent. + + The observation space has the purpose of: + 1. Reading the outputted state from the PrimAITE Simulation. + 2. Selecting parts of the simulation state that are requested by the simulation config + 3. Formatting this information so an agent can use it to make decisions. + """ + + # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed + # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next + # refactor. + + def __init__(self, observation: AbstractObservation) -> None: + """Initialise observation space. + + :param observation: Observation object + :type observation: AbstractObservation + """ + self.obs: AbstractObservation = observation + self.current_observation: ObsType + + def update(self, state: Dict) -> Dict: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + """ + self.current_observation = self.obs.observe(state) + return self.current_observation + + @property + def space(self) -> None: + """Gymnasium space object describing the observation space shape.""" + return self.obs.space + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": + """Create observation space from a config. + + :param config: Dictionary containing the configuration for this observation space. + It should contain the key 'type' which selects which observation class to use (from a choice of: + UC2BlueObservation, UC2RedObservation, UC2GreenObservation) + The other key is 'options' which are passed to the constructor of the selected observation class. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + """ + if config["type"] == "UC2BlueObservation": + return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) + elif config["type"] == "UC2RedObservation": + return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) + elif config["type"] == "UC2GreenObservation": + return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) + else: + raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations/observations.py similarity index 79% rename from src/primaite/game/agent/observations.py rename to src/primaite/game/agent/observations/observations.py index 82e11fe0..6d6614f4 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -4,7 +4,6 @@ from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces -from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -822,236 +821,3 @@ class ICSObservation(NullObservation): """ICS observation placeholder, currently not implemented so always returns a single 0.""" pass - - -class UC2BlueObservation(AbstractObservation): - """Container for all observations used by the blue agent in UC2. - - TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler - for the purpose of compiling several observation components. - """ - - def __init__( - self, - nodes: List[NodeObservation], - links: List[LinkObservation], - acl: AclObservation, - ics: ICSObservation, - where: Optional[List[str]] = None, - ) -> None: - """Initialise UC2 blue observation. - - :param nodes: List of node observations - :type nodes: List[NodeObservation] - :param links: List of link observations - :type links: List[LinkObservation] - :param acl: The Access Control List observation - :type acl: AclObservation - :param ics: The ICS observation - :type ics: ICSObservation - :param where: Where in the simulation state dict to find information. Not used in this particular observation - because it only compiles other observations and doesn't contribute any new information, defaults to None - :type where: Optional[List[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.nodes: List[NodeObservation] = nodes - self.links: List[LinkObservation] = links - self.acl: AclObservation = acl - self.ics: ICSObservation = ics - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, - "ACL": self.acl.default_observation, - "ICS": self.ics.default_observation, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs["ACL"] = self.acl.observe(state) - obs["ICS"] = self.ics.observe(state) - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), - "ACL": self.acl.space, - "ICS": self.ics.space, - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": - """Create UC2 blue observation from a config. - - :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, - links, ACL and ICS observations. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed UC2 blue observation - :rtype: UC2BlueObservation - """ - node_configs = config["nodes"] - - num_services_per_node = config["num_services_per_node"] - num_folders_per_node = config["num_folders_per_node"] - num_files_per_folder = config["num_files_per_folder"] - num_nics_per_node = config["num_nics_per_node"] - nodes = [ - NodeObservation.from_config( - config=n, - game=game, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - for n in node_configs - ] - - link_configs = config["links"] - links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] - - acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, game=game) - - ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, game=game) - new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) - return new - - -class UC2RedObservation(AbstractObservation): - """Container for all observations used by the red agent in UC2.""" - - def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: - super().__init__() - self.where: Optional[List[str]] = where - self.nodes: List[NodeObservation] = nodes - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": - """ - Create UC2 red observation from a config. - - :param config: Dictionary containing the configuration for this UC2 red observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] - return cls(nodes=nodes, where=["network"]) - - -class UC2GreenObservation(NullObservation): - """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" - - pass - - -class ObservationManager: - """ - Manage the observations of an Agent. - - The observation space has the purpose of: - 1. Reading the outputted state from the PrimAITE Simulation. - 2. Selecting parts of the simulation state that are requested by the simulation config - 3. Formatting this information so an agent can use it to make decisions. - """ - - # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed - # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next - # refactor. - - def __init__(self, observation: AbstractObservation) -> None: - """Initialise observation space. - - :param observation: Observation object - :type observation: AbstractObservation - """ - self.obs: AbstractObservation = observation - self.current_observation: ObsType - - def update(self, state: Dict) -> Dict: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - """ - self.current_observation = self.obs.observe(state) - return self.current_observation - - @property - def space(self) -> None: - """Gymnasium space object describing the observation space shape.""" - return self.obs.space - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": - """Create observation space from a config. - - :param config: Dictionary containing the configuration for this observation space. - It should contain the key 'type' which selects which observation class to use (from a choice of: - UC2BlueObservation, UC2RedObservation, UC2GreenObservation) - The other key is 'options' which are passed to the constructor of the selected observation class. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) - else: - raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/scripted_agents/__init__.py b/src/primaite/game/agent/scripted_agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py similarity index 100% rename from src/primaite/game/agent/data_manipulation_bot.py rename to src/primaite/game/agent/scripted_agents/data_manipulation_bot.py diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py similarity index 97% rename from src/primaite/game/agent/scripted_agents.py rename to src/primaite/game/agent/scripted_agents/probabilistic_agent.py index 5111df32..9cddc978 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -7,7 +7,7 @@ from gymnasium.core import ObsType from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractScriptedAgent -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py new file mode 100644 index 00000000..34a4b5ac --- /dev/null +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -0,0 +1,21 @@ +from typing import Dict, Tuple + +from gymnasium.core import ObsType + +from primaite.game.agent.interface import AbstractScriptedAgent + + +class RandomAgent(AbstractScriptedAgent): + """Agent that ignores its observation and acts completely at random.""" + + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Sample the action space randomly. + + :param obs: Current observation for this agent, not used in RandomAgent + :type obs: ObsType + :param timestep: The current simulation timestep, not used in RandomAgent + :type timestep: int + :return: Action formatted in CAOS format + :rtype: Tuple[str, Dict] + """ + return self.action_manager.get_action(self.action_manager.space.sample()) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 394a8154..33f9186b 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -6,11 +6,11 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.game.agent.scripted_agents import ProbabilisticAgent +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC diff --git a/tests/conftest.py b/tests/conftest.py index a117a1ef..20600e73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,8 @@ from _pytest.monkeypatch import MonkeyPatch from primaite import getLogger, PRIMAITE_PATHS from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent -from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations.observations import ICSObservation from primaite.game.agent.rewards import RewardFunction from primaite.game.game import PrimaiteGame from primaite.session.session import PrimaiteSession diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 3aff59af..f993af5f 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -5,8 +5,8 @@ from typing import Union import yaml from primaite.config.load import data_manipulation_config_path -from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import ProxyAgent, RandomAgent +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 8911632c..740fb491 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -10,28 +10,14 @@ # 4. Check that the simulation has changed in the way that I expect. # 5. Repeat for all actions. -from typing import Dict, Tuple +from typing import Tuple import pytest -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, ProxyAgent -from primaite.game.agent.observations import ICSObservation, ObservationManager -from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.simulator.system.services.dns.dns_server import DNSServer -from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index d1301759..b6aed30b 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -1,6 +1,6 @@ from gymnasium import spaces -from primaite.game.agent.observations import FileObservation +from primaite.game.agent.observations.observations import FileObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 698bfc72..4bbde32f 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,4 +1,4 @@ -from primaite.game.agent.observations import NicObservation +from primaite.game.agent.observations.observations import NicObservation from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.sim_container import Simulation diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index 73228e36..c556cfad 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -1,7 +1,8 @@ from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations.observations import ICSObservation from primaite.game.agent.rewards import RewardFunction -from primaite.game.agent.scripted_agents import ProbabilisticAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent def test_probabilistic_agent(): From ba58204542ffce55d49a1a8107c543f9ebc99ad0 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Mar 2024 14:08:35 +0000 Subject: [PATCH 3/8] #2350: split observations into smaller files --- .../agent/observations/agent_observations.py | 2 +- .../observations/file_system_observations.py | 177 ++++++++ .../agent/observations/node_observations.py | 199 +++++++++ .../game/agent/observations/observations.py | 412 ------------------ .../observations/software_observation.py | 71 +++ .../game_layer/test_observations.py | 2 +- 6 files changed, 449 insertions(+), 414 deletions(-) create mode 100644 src/primaite/game/agent/observations/file_system_observations.py create mode 100644 src/primaite/game/agent/observations/node_observations.py create mode 100644 src/primaite/game/agent/observations/software_observation.py diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py index 522cdb59..70a83881 100644 --- a/src/primaite/game/agent/observations/agent_observations.py +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -2,12 +2,12 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces +from primaite.game.agent.observations.node_observations import NodeObservation from primaite.game.agent.observations.observations import ( AbstractObservation, AclObservation, ICSObservation, LinkObservation, - NodeObservation, NullObservation, ) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py new file mode 100644 index 00000000..277bc51f --- /dev/null +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class FileObservation(AbstractObservation): + """Observation of a file on a node in the network.""" + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """ + Initialise file observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',,'files',] + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.default_observation: spaces.Space = {"health_status": 0} + "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + file_state = access_from_nested_dict(state, self.where) + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"health_status": file_state["visible_status"]} + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"health_status": spaces.Discrete(6)}) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": + """Create file observation from a config. + + :param config: Dictionary containing the configuration for this file observation. + :type config: Dict + :param game: _description_ + :type game: PrimaiteGame + :param parent_where: _description_, defaults to None + :type parent_where: _type_, optional + :return: _description_ + :rtype: _type_ + """ + return cls(where=parent_where + ["files", config["file_name"]]) + + +class FolderObservation(AbstractObservation): + """Folder observation, including files inside of the folder.""" + + def __init__( + self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 + ) -> None: + """Initialise folder Observation, including files inside the folder. + + :param where: Where in the simulation state dictionary to find the relevant information for this folder. + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',] + :type where: Optional[List[str]] + :param max_files: As size of the space must remain static, define max files that can be in this folder + , defaults to 5 + :type max_files: int, optional + :param file_positions: Defines the positioning within the observation space of particular files. This ensures + that even if new files are created, the existing files will always occupy the same space in the observation + space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the + observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same + name, it will take the position defined in this dict. Defaults to {} + :type file_positions: Dict[int, str], optional + """ + super().__init__() + + self.where: Optional[Tuple[str]] = where + + self.files: List[FileObservation] = files + while len(self.files) < num_files_per_folder: + self.files.append(FileObservation()) + while len(self.files) > num_files_per_folder: + truncated_file = self.files.pop() + msg = f"Too many files in folder observation. Truncating file {truncated_file}" + _LOGGER.warning(msg) + + self.default_observation = { + "health_status": 0, + "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + ) -> "FolderObservation": + """Create folder observation from a config. Also creates child file observations. + + :param config: Dictionary containing the configuration for this folder observation. Includes the name of the + folder and the files inside of it. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this folder's + parent node. A typical location for a node ``where`` can be: + ['network','nodes',,'file_system'] + :type parent_where: Optional[List[str]] + :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed folder observation + :rtype: FolderObservation + """ + where = parent_where + ["folders", config["folder_name"]] + + file_configs = config["files"] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] + + return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py new file mode 100644 index 00000000..93c6765b --- /dev/null +++ b/src/primaite/game/agent/observations/node_observations.py @@ -0,0 +1,199 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.file_system_observations import FolderObservation +from primaite.game.agent.observations.observations import AbstractObservation, NicObservation +from primaite.game.agent.observations.software_observation import ServiceObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NodeObservation(AbstractObservation): + """Observation of a node in the network. Includes services, folders and NICs.""" + + def __init__( + self, + where: Optional[Tuple[str]] = None, + services: List[ServiceObservation] = [], + folders: List[FolderObservation] = [], + network_interfaces: List[NicObservation] = [], + logon_status: bool = False, + num_services_per_node: int = 2, + num_folders_per_node: int = 2, + num_files_per_folder: int = 2, + num_nics_per_node: int = 2, + ) -> None: + """ + Configurable observation for a node in the simulation. + + :param where: Where in the simulation state dictionary for find relevant information for this observation. + A typical location for a node looks like this: + ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] + :type where: List[str], optional + :param services: Mapping between position in observation space and service name, defaults to {} + :type services: Dict[int,str], optional + :param max_services: Max number of services that can be presented in observation space for this node + , defaults to 2 + :type max_services: int, optional + :param folders: Mapping between position in observation space and folder name, defaults to {} + :type folders: Dict[int,str], optional + :param max_folders: Max number of folders in this node's obs space, defaults to 2 + :type max_folders: int, optional + :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} + :type network_interfaces: Dict[int,str], optional + :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 + :type max_nics: int, optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + self.services: List[ServiceObservation] = services + while len(self.services) < num_services_per_node: + # add empty service observation without `where` parameter so it always returns default (blank) observation + self.services.append(ServiceObservation()) + while len(self.services) > num_services_per_node: + truncated_service = self.services.pop() + msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" + _LOGGER.warning(msg) + # truncate service list + + self.folders: List[FolderObservation] = folders + # add empty folder observation without `where` parameter that will always return default (blank) observations + while len(self.folders) < num_folders_per_node: + self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) + while len(self.folders) > num_folders_per_node: + truncated_folder = self.folders.pop() + msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" + _LOGGER.warning(msg) + + self.network_interfaces: List[NicObservation] = network_interfaces + while len(self.network_interfaces) < num_nics_per_node: + self.network_interfaces.append(NicObservation()) + while len(self.network_interfaces) > num_nics_per_node: + truncated_nic = self.network_interfaces.pop() + msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" + _LOGGER.warning(msg) + + self.logon_status: bool = logon_status + + self.default_observation: Dict = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "operating_status": 0, + } + if self.logon_status: + self.default_observation["logon_status"] = 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NETWORK_INTERFACES"] = { + i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) + } + + if self.logon_status: + obs["logon_status"] = 0 + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + space_shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NETWORK_INTERFACES": spaces.Dict( + {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} + ), + } + if self.logon_status: + space_shape["logon_status"] = spaces.Discrete(3) + + return spaces.Dict(space_shape) + + @classmethod + def from_config( + cls, + config: Dict, + game: "PrimaiteGame", + parent_where: Optional[List[str]] = None, + num_services_per_node: int = 2, + num_folders_per_node: int = 2, + num_files_per_folder: int = 2, + num_nics_per_node: int = 2, + ) -> "NodeObservation": + """Create node observation from a config. Also creates child service, folder and NIC observations. + + :param config: Dictionary containing the configuration for this node observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this node's parent + network. A typical location for it would be: ['network',] + :type parent_where: Optional[List[str]] + :param num_services_per_node: How many spaces for services are in this node observation (to preserve static + observation size) , defaults to 2 + :type num_services_per_node: int, optional + :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static + observation size) , defaults to 2 + :type num_folders_per_node: int, optional + :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed node observation + :rtype: NodeObservation + """ + node_hostname = config["node_hostname"] + if parent_where is None: + where = ["network", "nodes", node_hostname] + else: + where = parent_where + ["nodes", node_hostname] + + svc_configs = config.get("services", {}) + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] + folder_configs = config.get("folders", {}) + folders = [ + FolderObservation.from_config( + config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder + ) + for c in folder_configs + ] + # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. + nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] + network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] + logon_status = config.get("logon_status", False) + return cls( + where=where, + services=services, + folders=folders, + network_interfaces=network_interfaces, + logon_status=logon_status, + num_services_per_node=num_services_per_node, + num_folders_per_node=num_folders_per_node, + num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, + ) diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 6d6614f4..10e69ea5 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -46,128 +46,6 @@ class AbstractObservation(ABC): pass -class FileObservation(AbstractObservation): - """Observation of a file on a node in the network.""" - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """ - Initialise file observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',,'files',] - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.default_observation: spaces.Space = {"health_status": 0} - "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - file_state = access_from_nested_dict(state, self.where) - if file_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return {"health_status": file_state["visible_status"]} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"health_status": spaces.Discrete(6)}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": - """Create file observation from a config. - - :param config: Dictionary containing the configuration for this file observation. - :type config: Dict - :param game: _description_ - :type game: PrimaiteGame - :param parent_where: _description_, defaults to None - :type parent_where: _type_, optional - :return: _description_ - :rtype: _type_ - """ - return cls(where=parent_where + ["files", config["file_name"]]) - - -class ServiceObservation(AbstractObservation): - """Observation of a service in the network.""" - - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} - "Default observation is what should be returned when the service doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise service observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'services', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - service_state = access_from_nested_dict(state, self.where) - if service_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], - } - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ServiceObservation": - """Create service observation from a config. - - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation - :rtype: ServiceObservation - """ - return cls(where=parent_where + ["services", config["service_name"]]) - - class LinkObservation(AbstractObservation): """Observation of a link in the network.""" @@ -238,111 +116,6 @@ class LinkObservation(AbstractObservation): return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -class FolderObservation(AbstractObservation): - """Folder observation, including files inside of the folder.""" - - def __init__( - self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 - ) -> None: - """Initialise folder Observation, including files inside of the folder. - - :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',] - :type where: Optional[List[str]] - :param max_files: As size of the space must remain static, define max files that can be in this folder - , defaults to 5 - :type max_files: int, optional - :param file_positions: Defines the positioning within the observation space of particular files. This ensures - that even if new files are created, the existing files will always occupy the same space in the observation - space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the - observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same - name, it will take the position defined in this dict. Defaults to {} - :type file_positions: Dict[int, str], optional - """ - super().__init__() - - self.where: Optional[Tuple[str]] = where - - self.files: List[FileObservation] = files - while len(self.files) < num_files_per_folder: - self.files.append(FileObservation()) - while len(self.files) > num_files_per_folder: - truncated_file = self.files.pop() - msg = f"Too many files in folder observation. Truncating file {truncated_file}" - _LOGGER.warning(msg) - - self.default_observation = { - "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - folder_state = access_from_nested_dict(state, self.where) - if folder_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - health_status = folder_state["health_status"] - - obs = {} - - obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 - ) -> "FolderObservation": - """Create folder observation from a config. Also creates child file observations. - - :param config: Dictionary containing the configuration for this folder observation. Includes the name of the - folder and the files inside of it. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node ``where`` can be: - ['network','nodes',,'file_system'] - :type parent_where: Optional[List[str]] - :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed folder observation - :rtype: FolderObservation - """ - where = parent_where + ["folders", config["folder_name"]] - - file_configs = config["files"] - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] - - return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) - - class NicObservation(AbstractObservation): """Observation of a Network Interface Card (NIC) in the network.""" @@ -444,191 +217,6 @@ class NicObservation(AbstractObservation): return cls(where=parent_where + ["NICs", config["nic_num"]]) -class NodeObservation(AbstractObservation): - """Observation of a node in the network. Includes services, folders and NICs.""" - - def __init__( - self, - where: Optional[Tuple[str]] = None, - services: List[ServiceObservation] = [], - folders: List[FolderObservation] = [], - network_interfaces: List[NicObservation] = [], - logon_status: bool = False, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> None: - """ - Configurable observation for a node in the simulation. - - :param where: Where in the simulation state dictionary for find relevant information for this observation. - A typical location for a node looks like this: - ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] - :type where: List[str], optional - :param services: Mapping between position in observation space and service name, defaults to {} - :type services: Dict[int,str], optional - :param max_services: Max number of services that can be presented in observation space for this node - , defaults to 2 - :type max_services: int, optional - :param folders: Mapping between position in observation space and folder name, defaults to {} - :type folders: Dict[int,str], optional - :param max_folders: Max number of folders in this node's obs space, defaults to 2 - :type max_folders: int, optional - :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} - :type network_interfaces: Dict[int,str], optional - :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 - :type max_nics: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.services: List[ServiceObservation] = services - while len(self.services) < num_services_per_node: - # add empty service observation without `where` parameter so it always returns default (blank) observation - self.services.append(ServiceObservation()) - while len(self.services) > num_services_per_node: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - # truncate service list - - self.folders: List[FolderObservation] = folders - # add empty folder observation without `where` parameter that will always return default (blank) observations - while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) - while len(self.folders) > num_folders_per_node: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NicObservation] = network_interfaces - while len(self.network_interfaces) < num_nics_per_node: - self.network_interfaces.append(NicObservation()) - while len(self.network_interfaces) > num_nics_per_node: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warning(msg) - - self.logon_status: bool = logon_status - - self.default_observation: Dict = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, - } - if self.logon_status: - self.default_observation["logon_status"] = 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NETWORK_INTERFACES"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - - if self.logon_status: - obs["logon_status"] = 0 - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - space_shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NETWORK_INTERFACES": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - } - if self.logon_status: - space_shape["logon_status"] = spaces.Discrete(3) - - return spaces.Dict(space_shape) - - @classmethod - def from_config( - cls, - config: Dict, - game: "PrimaiteGame", - parent_where: Optional[List[str]] = None, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> "NodeObservation": - """Create node observation from a config. Also creates child service, folder and NIC observations. - - :param config: Dictionary containing the configuration for this node observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this node's parent - network. A typical location for it would be: ['network',] - :type parent_where: Optional[List[str]] - :param num_services_per_node: How many spaces for services are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_services_per_node: int, optional - :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_folders_per_node: int, optional - :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed node observation - :rtype: NodeObservation - """ - node_hostname = config["node_hostname"] - if parent_where is None: - where = ["network", "nodes", node_hostname] - else: - where = parent_where + ["nodes", node_hostname] - - svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] - folder_configs = config.get("folders", {}) - folders = [ - FolderObservation.from_config( - config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder - ) - for c in folder_configs - ] - # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. - nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] - network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] - logon_status = config.get("logon_status", False) - return cls( - where=where, - services=services, - folders=folders, - network_interfaces=network_interfaces, - logon_status=logon_status, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - - class AclObservation(AbstractObservation): """Observation of an Access Control List (ACL) in the network.""" diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py new file mode 100644 index 00000000..eae9dc1f --- /dev/null +++ b/src/primaite/game/agent/observations/software_observation.py @@ -0,0 +1,71 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class ServiceObservation(AbstractObservation): + """Observation of a service in the network.""" + + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} + "Default observation is what should be returned when the service doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise service observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'services', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + service_state = access_from_nested_dict(state, self.where) + if service_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": service_state["operating_state"], + "health_status": service_state["health_state_visible"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None + ) -> "ServiceObservation": + """Create service observation from a config. + + :param config: Dictionary containing the configuration for this service observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. + :type parent_where: Optional[List[str]], optional + :return: Constructed service observation + :rtype: ServiceObservation + """ + return cls(where=parent_where + ["services", config["service_name"]]) diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index b6aed30b..f52b52f7 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -1,6 +1,6 @@ from gymnasium import spaces -from primaite.game.agent.observations.observations import FileObservation +from primaite.game.agent.observations.file_system_observations import FileObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation From 61aa24212847763737e2ef2759b05343ee0b5ef4 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Mar 2024 14:48:31 +0000 Subject: [PATCH 4/8] #2350: tests + application --- .../observations/software_observation.py | 92 +++++++++++++++++++ src/primaite/simulator/file_system/folder.py | 2 +- .../game_layer/observations/__init__.py | 0 .../test_file_system_observations.py | 68 ++++++++++++++ .../observations/test_node_observations.py | 43 +++++++++ .../observations/test_observations.py | 35 +++++++ .../test_software_observations.py | 66 +++++++++++++ 7 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/game_layer/observations/__init__.py create mode 100644 tests/integration_tests/game_layer/observations/test_file_system_observations.py create mode 100644 tests/integration_tests/game_layer/observations/test_node_observations.py create mode 100644 tests/integration_tests/game_layer/observations/test_observations.py create mode 100644 tests/integration_tests/game_layer/observations/test_software_observations.py diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index eae9dc1f..ff61714a 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -69,3 +69,95 @@ class ServiceObservation(AbstractObservation): :rtype: ServiceObservation """ return cls(where=parent_where + ["services", config["service_name"]]) + + +class ApplicationObservation(AbstractObservation): + """Observation of an application in the network.""" + + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0} + "Default observation is what should be returned when the application doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise application observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'applications', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + app_state = access_from_nested_dict(state, self.where) + if app_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": app_state["operating_state"], + "health_status": app_state["health_state_visible"], + "num_executions": self._categorise_num_executions(app_state["num_executions"]), + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "operating_status": spaces.Discrete(7), + "health_status": spaces.Discrete(6), + "num_executions": spaces.Discrete(4), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None + ) -> "ApplicationObservation": + """Create application observation from a config. + + :param config: Dictionary containing the configuration for this service observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. + :type parent_where: Optional[List[str]], optional + :return: Constructed service observation + :rtype: ApplicationObservation + """ + return cls(where=parent_where + ["services", config["application_name"]]) + + @classmethod + def _categorise_num_executions(cls, num_executions: int) -> int: + """ + Categorise the number of executions of an application. + + Helps classify the number of application executions into different categories. + + Current categories: + - 0: Application is never executed + - 1: Application is executed a low number of times (1-5) + - 2: Application is executed often (6-10) + - 3: Application is executed a high number of times (more than 10) + + :param: num_executions: Number of times the application is executed + """ + if num_executions > 10: + return 3 + elif num_executions > 5: + return 2 + elif num_executions > 0: + return 1 + return 0 diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 3ddc1e5f..529bfe11 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -130,7 +130,7 @@ class Folder(FileSystemItemABC): file = self.get_file_by_id(file_uuid=file_id) file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: - self.visible_health_status = FileSystemItemHealthStatus.CORRUPT + self.health_status = FileSystemItemHealthStatus.CORRUPT self.visible_health_status = self.health_status def _reveal_to_red_timestep(self) -> None: diff --git a/tests/integration_tests/game_layer/observations/__init__.py b/tests/integration_tests/game_layer/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py new file mode 100644 index 00000000..808007cc --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -0,0 +1,68 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_file_observation(simulation): + """Test the file observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file on the pc + file = pc.file_system.create_file(file_name="dog.png") + + dog_file_obs = FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] + ) + + assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) + + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # good initial + + file.corrupt() + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan file so this changes + + file.scan() + file.apply_timestep(0) # apply time step + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # corrupted + + +def test_folder_observation(simulation): + """Test the folder observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file and folder on the pc + folder = pc.file_system.create_folder("test_folder") + file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") + + root_folder_obs = FolderObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] + ) + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("FILES") is not None + assert observation_state.get("health_status") == 1 + + file.corrupt() # corrupt just the file + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan folder to change this + + folder.scan() + for i in range(folder.scan_duration + 1): + folder.apply_timestep(i) # apply as many timesteps as needed for a scan + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py new file mode 100644 index 00000000..835202c6 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -0,0 +1,43 @@ +import copy +from uuid import uuid4 + +import pytest + +from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_node_observation(simulation): + """Test a Node observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 1 # computer is on + + assert observation_state.get("SERVICES") is not None + assert observation_state.get("FOLDERS") is not None + assert observation_state.get("NETWORK_INTERFACES") is not None + + # turn off computer + pc.power_off() + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 4 # shutting down + + for i in range(pc.shut_down_duration + 1): + pc.apply_timestep(i) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 2 diff --git a/tests/integration_tests/game_layer/observations/test_observations.py b/tests/integration_tests/game_layer/observations/test_observations.py new file mode 100644 index 00000000..eccda238 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_observations.py @@ -0,0 +1,35 @@ +import pytest + +from primaite.game.agent.observations.observations import NicObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_nic(simulation): + """Test the NIC observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic: NIC = pc.network_interface[1] + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 1 # enabled + assert observation_state.get("nmne") is not None + assert observation_state["nmne"].get("inbound") == 0 + assert observation_state["nmne"].get("outbound") == 0 + + nic.disable() + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 2 # disabled diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py new file mode 100644 index 00000000..17fc386f --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -0,0 +1,66 @@ +import pytest + +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_service_observation(simulation): + """Test the service observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(NTPServer) + + ntp_server = pc.software_manager.software.get("NTPServer") + assert ntp_server + + service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + + observation_state = service_obs.observe(simulation.describe_state()) + + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 1 # running + + ntp_server.restart() + observation_state = service_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 6 # resetting + + +def test_application_observation(simulation): + """Test the application observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(DatabaseClient) + + web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") + assert web_browser + + app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + + web_browser.close() + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 2 # stopped + assert observation_state.get("num_executions") == 0 + + web_browser.run() + web_browser.scan() # scan to update health status + web_browser.get_webpage("test") + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 + assert observation_state.get("operating_status") == 1 # running + assert observation_state.get("num_executions") == 1 From cc721056d89563d64aae94ae9c936480a7c6388a Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 8 Mar 2024 19:32:07 +0000 Subject: [PATCH 5/8] #2350: configurable NMNE category thresholds --- .../_package_data/data_manipulation.yaml | 5 + .../agent/observations/nic_observations.py | 175 ++++++++++++++++++ .../agent/observations/node_observations.py | 3 +- .../game/agent/observations/observations.py | 102 ---------- src/primaite/game/game.py | 7 +- ...software_installation_and_configuration.py | 11 +- .../test_game_options_config.py | 25 +++ .../observations/test_observations.py | 42 ++++- .../network/test_capture_nmne.py | 2 +- 9 files changed, 261 insertions(+), 111 deletions(-) create mode 100644 src/primaite/game/agent/observations/nic_observations.py create mode 100644 tests/integration_tests/configuration_file_parsing/test_game_options_config.py diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index dffb40ea..47204878 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -30,6 +30,11 @@ game: - ICMP - TCP - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 agents: - ref: client_2_green_user diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py new file mode 100644 index 00000000..39298ffe --- /dev/null +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -0,0 +1,175 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import CAPTURE_NMNE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NicObservation(AbstractObservation): + """Observation of a Network Interface Card (NIC) in the network.""" + + low_nmne_threshold: int = 0 + """The minimum number of malicious network events to be considered low.""" + med_nmne_threshold: int = 5 + """The minimum number of malicious network events to be considered medium.""" + high_nmne_threshold: int = 10 + """The minimum number of malicious network events to be considered high.""" + + @property + def default_observation(self) -> Dict: + """The default NIC observation dict.""" + data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"nmne": {"inbound": 0, "outbound": 0}}) + + return data + + def __init__( + self, + where: Optional[Tuple[str]] = None, + low_nmne_threshold: Optional[int] = 0, + med_nmne_threshold: Optional[int] = 5, + high_nmne_threshold: Optional[int] = 10, + ) -> None: + """Initialise NIC observation. + + :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical + example may look like this: + ['network','nodes',,'NICs',] + If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. + :type where: Optional[Tuple[str]], optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold: + self._validate_nmne_categories( + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) + + def _validate_nmne_categories( + self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10 + ): + """ + Validates the nmne threshold config. + + If the configuration is valid, the thresholds will be set, otherwise, an exception is raised. + + :param: low_nmne_threshold: The minimum number of malicious network events to be considered low + :param: med_nmne_threshold: The minimum number of malicious network events to be considered medium + :param: high_nmne_threshold: The minimum number of malicious network events to be considered high + """ + if high_nmne_threshold <= med_nmne_threshold: + raise Exception( + f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater " + f"than medium nmne count ({med_nmne_threshold})" + ) + + if med_nmne_threshold <= low_nmne_threshold: + raise Exception( + f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater " + f"than low nmne count ({low_nmne_threshold})" + ) + + self.high_nmne_threshold = high_nmne_threshold + self.med_nmne_threshold = med_nmne_threshold + self.low_nmne_threshold = low_nmne_threshold + + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (default 1-5 events). + - 2: Moderate number of MNEs (default 6-10 events). + - 3: High number of MNEs (default more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. + """ + if nmne_count > self.high_nmne_threshold: + return 3 + elif nmne_count > self.med_nmne_threshold: + return 2 + elif nmne_count > self.low_nmne_threshold: + return 1 + return 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + nic_state = access_from_nested_dict(state, self.where) + + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + else: + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"nmne": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) + obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) + return obs_dict + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "nic_status": spaces.Discrete(3), + "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": + """Create NIC observation from a config. + + :param config: Dictionary containing the configuration for this NIC observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent + node. A typical location for a node ``where`` can be: ['network','nodes',] + :type parent_where: Optional[List[str]] + :return: Constructed NIC observation + :rtype: NicObservation + """ + low_nmne_threshold = None + med_nmne_threshold = None + high_nmne_threshold = None + + if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"): + threshold = game.options.thresholds["nmne"] + + low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None + med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None + high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None + + return cls( + where=parent_where + ["NICs", config["nic_num"]], + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 93c6765b..f211a6b5 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -4,7 +4,8 @@ from gymnasium import spaces from primaite import getLogger from primaite.game.agent.observations.file_system_observations import FolderObservation -from primaite.game.agent.observations.observations import AbstractObservation, NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.agent.observations.observations import AbstractObservation from primaite.game.agent.observations.software_observation import ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 10e69ea5..6236b00d 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -7,7 +7,6 @@ from gymnasium import spaces from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.nmne import CAPTURE_NMNE _LOGGER = getLogger(__name__) @@ -116,107 +115,6 @@ class LinkObservation(AbstractObservation): return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -class NicObservation(AbstractObservation): - """Observation of a Network Interface Card (NIC) in the network.""" - - @property - def default_observation(self) -> Dict: - """The default NIC observation dict.""" - data = {"nic_status": 0} - if CAPTURE_NMNE: - data.update({"nmne": {"inbound": 0, "outbound": 0}}) - - return data - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise NIC observation. - - :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical - example may look like this: - ['network','nodes',,'NICs',] - If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. - :type where: Optional[Tuple[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def _categorise_mne_count(self, nmne_count: int) -> int: - """ - Categorise the number of Malicious Network Events (NMNEs) into discrete bins. - - This helps in classifying the severity or volume of MNEs into manageable levels for the agent. - - Bins are defined as follows: - - 0: No MNEs detected (0 events). - - 1: Low number of MNEs (1-5 events). - - 2: Moderate number of MNEs (6-10 events). - - 3: High number of MNEs (more than 10 events). - - :param nmne_count: Number of MNEs detected. - :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. - """ - if nmne_count > 10: - return 3 - elif nmne_count > 5: - return 2 - elif nmne_count > 0: - return 1 - return 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - nic_state = access_from_nested_dict(state, self.where) - - if nic_state is NOT_PRESENT_IN_STATE: - return self.default_observation - else: - obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} - if CAPTURE_NMNE: - obs_dict.update({"nmne": {}}) - direction_dict = nic_state["nmne"].get("direction", {}) - inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) - inbound_count = inbound_keywords.get("*", 0) - outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) - outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) - return obs_dict - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "nic_status": spaces.Discrete(3), - "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": - """Create NIC observation from a config. - - :param config: Dictionary containing the configuration for this NIC observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent - node. A typical location for a node ``where`` can be: ['network','nodes',] - :type parent_where: Optional[List[str]] - :return: Constructed NIC observation - :rtype: NicObservation - """ - return cls(where=parent_where + ["NICs", config["nic_num"]]) - - class AclObservation(AbstractObservation): """Observation of an Access Control List (ACL) in the network.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 33f9186b..3edb8651 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,6 +1,6 @@ """PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from pydantic import BaseModel, ConfigDict @@ -67,8 +67,13 @@ class PrimaiteGameOptions(BaseModel): model_config = ConfigDict(extra="forbid") max_episode_length: int = 256 + """Maximum number of episodes for the PrimAITE game.""" ports: List[str] + """A whitelist of available ports in the simulation.""" protocols: List[str] + """A whitelist of available protocols in the simulation.""" + thresholds: Optional[Dict] = {} + """A dict containing the thresholds used for determining what is acceptable during observations.""" class PrimaiteGame: diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index f993af5f..a5fcb372 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -5,8 +5,9 @@ from typing import Union import yaml from primaite.config.load import data_manipulation_config_path -from primaite.game.agent.interface import ProxyAgent, RandomAgent +from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -43,15 +44,15 @@ def test_example_config(): # green agent 1 assert "client_2_green_user" in game.agents - assert isinstance(game.agents["client_2_green_user"], RandomAgent) + assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent) # green agent 2 assert "client_1_green_user" in game.agents - assert isinstance(game.agents["client_1_green_user"], RandomAgent) + assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent) # red agent - assert "client_1_data_manipulation_red_bot" in game.agents - assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) + assert "data_manipulation_attacker" in game.agents + assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent) # blue agent assert "defender" in game.agents diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py new file mode 100644 index 00000000..adbbf2b5 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -0,0 +1,25 @@ +from pathlib import Path +from typing import Union + +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_thresholds(): + """Test that the game options can be parsed correctly.""" + game = load_config(data_manipulation_config_path()) + + assert game.options.thresholds is not None diff --git a/tests/integration_tests/game_layer/observations/test_observations.py b/tests/integration_tests/game_layer/observations/test_observations.py index eccda238..97df7882 100644 --- a/tests/integration_tests/game_layer/observations/test_observations.py +++ b/tests/integration_tests/game_layer/observations/test_observations.py @@ -1,6 +1,6 @@ import pytest -from primaite.game.agent.observations.observations import NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.sim_container import Simulation @@ -33,3 +33,43 @@ def test_nic(simulation): nic.disable() observation_state = nic_obs.observe(simulation.describe_state()) assert observation_state.get("nic_status") == 2 # disabled + + +def test_nic_categories(simulation): + """Test the NIC observation nmne count categories.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + assert nic_obs.high_nmne_threshold == 10 # default + assert nic_obs.med_nmne_threshold == 5 # default + assert nic_obs.low_nmne_threshold == 0 # default + + nic_obs = NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + assert nic_obs.high_nmne_threshold == 9 + assert nic_obs.med_nmne_threshold == 6 + assert nic_obs.low_nmne_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=9, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=9, + high_nmne_threshold=9, + ) diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 4bbde32f..32d4ee8f 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,4 +1,4 @@ -from primaite.game.agent.observations.observations import NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.sim_container import Simulation From a228a099175aeb40acaad33af87d13b19a6c34ef Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Sun, 10 Mar 2024 15:13:37 +0000 Subject: [PATCH 6/8] #2350: documentation --- docs/source/configuration/game.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/configuration/game.rst b/docs/source/configuration/game.rst index e43ea224..828571a7 100644 --- a/docs/source/configuration/game.rst +++ b/docs/source/configuration/game.rst @@ -23,6 +23,11 @@ This section defines high-level settings that apply across the game, currently i - ICMP - TCP - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 ``max_episode_length`` ---------------------- @@ -44,3 +49,8 @@ See :ref:`List of Ports ` for a list of ports. A list of protocols that the Reinforcement Learning agent(s) are able to see in the observation space. See :ref:`List of IPProtocols ` for a list of protocols. + +``thresholds`` +-------------- + +These are used to determine the thresholds of high, medium and low categories for counted observation occurrences. From cd6d6325db51ab7857efaf8af4fba03f06f79aa9 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 11 Mar 2024 17:47:33 +0000 Subject: [PATCH 7/8] #2350: add tests to check spaces + acl obs test + nmne space changes --- .../_package_data/data_manipulation.yaml | 2 - .../agent/observations/nic_observations.py | 29 ++++++-- .../observations/software_observation.py | 2 +- .../observations/test_acl_observations.py | 66 +++++++++++++++++ .../test_file_system_observations.py | 4 +- .../observations/test_link_observations.py | 73 +++++++++++++++++++ ...servations.py => test_nic_observations.py} | 22 ++++++ .../observations/test_node_observations.py | 3 + .../test_software_observations.py | 4 + 9 files changed, 193 insertions(+), 12 deletions(-) create mode 100644 tests/integration_tests/game_layer/observations/test_acl_observations.py create mode 100644 tests/integration_tests/game_layer/observations/test_link_observations.py rename tests/integration_tests/game_layer/observations/{test_observations.py => test_nic_observations.py} (76%) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 47204878..a3a7e44a 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -22,8 +22,6 @@ io_settings: game: max_episode_length: 256 ports: - - ARP - - DNS - HTTP - POSTGRES_SERVER protocols: diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 39298ffe..735b41d4 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -20,6 +20,8 @@ class NicObservation(AbstractObservation): high_nmne_threshold: int = 10 """The minimum number of malicious network events to be considered high.""" + global CAPTURE_NMNE + @property def default_observation(self) -> Dict: """The default NIC observation dict.""" @@ -47,6 +49,15 @@ class NicObservation(AbstractObservation): super().__init__() self.where: Optional[Tuple[str]] = where + global CAPTURE_NMNE + if CAPTURE_NMNE: + self.nmne_inbound_last_step: int = 0 + """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets + us find the difference.""" + self.nmne_outbound_last_step: int = 0 + """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets + us find the difference.""" + if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold: self._validate_nmne_categories( low_nmne_threshold=low_nmne_threshold, @@ -128,19 +139,21 @@ class NicObservation(AbstractObservation): inbound_count = inbound_keywords.get("*", 0) outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count) + obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count return obs_dict @property def space(self) -> spaces.Space: """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "nic_status": spaces.Discrete(3), - "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), - } - ) + space = spaces.Dict({"nic_status": spaces.Discrete(3)}) + + if CAPTURE_NMNE: + space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + + return space @classmethod def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index ff61714a..6caf791c 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -51,7 +51,7 @@ class ServiceObservation(AbstractObservation): @property def space(self) -> spaces.Space: """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod def from_config( diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py new file mode 100644 index 00000000..93867edd --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -0,0 +1,66 @@ +import pytest + +from primaite.game.agent.observations.observations import AclObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_acl_observations(simulation): + """Test the ACL rule observations.""" + router: Router = simulation.network.get_node_by_hostname("router_1") + client_1: Computer = simulation.network.get_node_by_hostname("client_1") + server: Computer = simulation.network.get_node_by_hostname("server_1") + + # quick set up of ntp + client_1.software_manager.install(NTPClient) + ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient") + ntp_client.configure(server.network_interface.get(1).ip_address) + server.software_manager.install(NTPServer) + + # add router acl rule + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + + acl_obs = AclObservation( + where=["network", "nodes", router.hostname, "acl", "acl"], + node_ip_to_id={}, + ports=["NTP", "HTTP", "POSTGRES_SERVER"], + protocols=["TCP", "UDP", "ICMP"], + ) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1) + assert rule_obs.get("permission") == 1 # permit = 1 deny = 2 + assert rule_obs.get("source_node_id") == 1 # applies to all source nodes + assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes + assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs) + assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2 + assert rule_obs.get("protocol") == 1 # 1 = No Protocol + + router.acl.remove_rule(1) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 + assert rule_obs.get("permission") == 0 + assert rule_obs.get("source_node_id") == 0 + assert rule_obs.get("dest_node_id") == 0 + assert rule_obs.get("source_port") == 0 + assert rule_obs.get("dest_port") == 0 + assert rule_obs.get("protocol") == 0 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 808007cc..35bb95fd 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -26,7 +26,7 @@ def test_file_observation(simulation): where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] ) - assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) + assert dog_file_obs.space["health_status"] == spaces.Discrete(6) observation_state = dog_file_obs.observe(simulation.describe_state()) assert observation_state.get("health_status") == 1 # good initial @@ -52,6 +52,8 @@ def test_folder_observation(simulation): where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] ) + assert root_folder_obs.space["health_status"] == spaces.Discrete(6) + observation_state = root_folder_obs.observe(simulation.describe_state()) assert observation_state.get("FILES") is not None assert observation_state.get("health_status") == 1 diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py new file mode 100644 index 00000000..bfe4d5cc --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -0,0 +1,73 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.observations import LinkObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.base import Link, Node +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation() -> Simulation: + sim = Simulation() + + network = Network() + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Connect Computer and Server + network.connect(computer.network_interface[1], server.network_interface[1]) + + # Should be linked + assert next(iter(network.links.values())).is_up + + assert computer.ping(server.network_interface.get(1).ip_address) + + # set simulation network as example network + sim.network = network + + return sim + + +def test_link_observation(simulation): + """Test the link observation.""" + # get a link + link: Link = next(iter(simulation.network.links.values())) + + computer: Computer = simulation.network.get_node_by_hostname("computer") + server: Server = simulation.network.get_node_by_hostname("server") + + simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep + + link_obs = LinkObservation(where=["network", "links", link.uuid]) + + assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10 + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state.get("PROTOCOLS") is not None + assert observation_state["PROTOCOLS"]["ALL"] == 0 + + computer.ping(server.network_interface.get(1).ip_address) + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state["PROTOCOLS"]["ALL"] == 1 diff --git a/tests/integration_tests/game_layer/observations/test_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py similarity index 76% rename from tests/integration_tests/game_layer/observations/test_observations.py rename to tests/integration_tests/game_layer/observations/test_nic_observations.py index 97df7882..c210b751 100644 --- a/tests/integration_tests/game_layer/observations/test_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -1,9 +1,27 @@ +from pathlib import Path +from typing import Union + import pytest +import yaml +from gymnasium import spaces from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.network.nmne import CAPTURE_NMNE from primaite.simulator.sim_container import Simulation +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) @pytest.fixture(scope="function") @@ -24,6 +42,10 @@ def test_nic(simulation): nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + assert nic_obs.space["nic_status"] == spaces.Discrete(3) + assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4) + assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4) + observation_state = nic_obs.observe(simulation.describe_state()) assert observation_state.get("nic_status") == 1 # enabled assert observation_state.get("nmne") is not None diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 835202c6..b1563fbd 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -2,6 +2,7 @@ import copy from uuid import uuid4 import pytest +from gymnasium import spaces from primaite.game.agent.observations.node_observations import NodeObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -24,6 +25,8 @@ def test_node_observation(simulation): node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + assert node_obs.space["operating_status"] == spaces.Discrete(5) + observation_state = node_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 1 # computer is on diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 17fc386f..4ae0701e 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -1,4 +1,5 @@ import pytest +from gymnasium import spaces from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -29,6 +30,9 @@ def test_service_observation(simulation): service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + assert service_obs.space["operating_status"] == spaces.Discrete(7) + assert service_obs.space["health_status"] == spaces.Discrete(5) + observation_state = service_obs.observe(simulation.describe_state()) assert observation_state.get("health_status") == 0 From f2c6f10c21f445cf5d85db808ad4092ffa923993 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 12 Mar 2024 12:20:02 +0000 Subject: [PATCH 8/8] #2350: apply PR suggestions --- .../game/agent/observations/nic_observations.py | 10 +++++----- .../game/agent/observations/node_observations.py | 6 +++--- .../simulator/system/applications/database_client.py | 3 +-- .../red_applications/data_manipulation_bot.py | 3 ++- .../game_layer/observations/test_nic_observations.py | 10 +++++----- .../game_layer/observations/test_node_observations.py | 2 +- tests/integration_tests/network/test_capture_nmne.py | 4 ++-- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 735b41d4..de83e03a 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -27,7 +27,7 @@ class NicObservation(AbstractObservation): """The default NIC observation dict.""" data = {"nic_status": 0} if CAPTURE_NMNE: - data.update({"nmne": {"inbound": 0, "outbound": 0}}) + data.update({"NMNE": {"inbound": 0, "outbound": 0}}) return data @@ -133,14 +133,14 @@ class NicObservation(AbstractObservation): else: obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} if CAPTURE_NMNE: - obs_dict.update({"nmne": {}}) + obs_dict.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) inbound_count = inbound_keywords.get("*", 0) outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) self.nmne_inbound_last_step = inbound_count self.nmne_outbound_last_step = outbound_count return obs_dict @@ -151,7 +151,7 @@ class NicObservation(AbstractObservation): space = spaces.Dict({"nic_status": spaces.Discrete(3)}) if CAPTURE_NMNE: - space["nmne"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) return space diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index f211a6b5..94f0974b 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -86,7 +86,7 @@ class NodeObservation(AbstractObservation): self.default_observation: Dict = { "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, } if self.logon_status: @@ -111,7 +111,7 @@ class NodeObservation(AbstractObservation): obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} obs["operating_status"] = node_state["operating_state"] - obs["NETWORK_INTERFACES"] = { + obs["NICS"] = { i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) } @@ -127,7 +127,7 @@ class NodeObservation(AbstractObservation): "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), "operating_status": spaces.Discrete(5), - "NETWORK_INTERFACES": spaces.Dict( + "NICS": spaces.Dict( {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} ), } diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index bc51b3a2..d3afef59 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -48,6 +48,7 @@ class DatabaseClient(Application): def execute(self) -> bool: """Execution definition for db client: perform a select query.""" + self.num_executions += 1 # trying to connect counts as an execution if self.connections: can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1]) else: @@ -82,8 +83,6 @@ class DatabaseClient(Application): if not self._can_perform_action(): return False - self.num_executions += 1 # trying to connect counts as an execution - if not connection_id: connection_id = str(uuid4()) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 2a6c2b11..ee276971 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -193,6 +193,8 @@ class DataManipulationBot(Application): if not self._can_perform_action(): _LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.") self.run() + + self.num_executions += 1 return self._application_loop() def _application_loop(self) -> bool: @@ -202,7 +204,6 @@ class DataManipulationBot(Application): This is the core loop where the bot sequentially goes through the stages of the attack. """ if not self._can_perform_action(): - self.num_executions += 1 return False if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index c210b751..332bc1f7 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -43,14 +43,14 @@ def test_nic(simulation): nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) assert nic_obs.space["nic_status"] == spaces.Discrete(3) - assert nic_obs.space["nmne"]["inbound"] == spaces.Discrete(4) - assert nic_obs.space["nmne"]["outbound"] == spaces.Discrete(4) + assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) + assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4) observation_state = nic_obs.observe(simulation.describe_state()) assert observation_state.get("nic_status") == 1 # enabled - assert observation_state.get("nmne") is not None - assert observation_state["nmne"].get("inbound") == 0 - assert observation_state["nmne"].get("outbound") == 0 + assert observation_state.get("NMNE") is not None + assert observation_state["NMNE"].get("inbound") == 0 + assert observation_state["NMNE"].get("outbound") == 0 nic.disable() observation_state = nic_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index b1563fbd..dce05b6a 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -32,7 +32,7 @@ def test_node_observation(simulation): assert observation_state.get("SERVICES") is not None assert observation_state.get("FOLDERS") is not None - assert observation_state.get("NETWORK_INTERFACES") is not None + assert observation_state.get("NICS") is not None # turn off computer pc.power_off() diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 85fcf102..9efc70f7 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -179,8 +179,8 @@ def test_capture_nmne_observations(uc2_network): # Observe the current state of NMNEs from the NICs of both the database and web servers state = sim.describe_state() - db_nic_obs = db_server_nic_obs.observe(state)["nmne"] - web_nic_obs = web_server_nic_obs.observe(state)["nmne"] + db_nic_obs = db_server_nic_obs.observe(state)["NMNE"] + web_nic_obs = web_server_nic_obs.observe(state)["NMNE"] # Define expected NMNE values based on the iteration count if i > 10: