diff --git a/docs/source/configuration/game.rst b/docs/source/configuration/game.rst index e43ea224..828571a7 100644 --- a/docs/source/configuration/game.rst +++ b/docs/source/configuration/game.rst @@ -23,6 +23,11 @@ This section defines high-level settings that apply across the game, currently i - ICMP - TCP - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 ``max_episode_length`` ---------------------- @@ -44,3 +49,8 @@ See :ref:`List of Ports ` for a list of ports. A list of protocols that the Reinforcement Learning agent(s) are able to see in the observation space. See :ref:`List of IPProtocols ` for a list of protocols. + +``thresholds`` +-------------- + +These are used to determine the thresholds of high, medium and low categories for counted observation occurrences. diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index dffb40ea..a3a7e44a 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -22,14 +22,17 @@ io_settings: game: max_episode_length: 256 ports: - - ARP - - DNS - HTTP - POSTGRES_SERVER protocols: - ICMP - TCP - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 agents: - ref: client_2_green_user diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 88848479..e641fabb 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction if TYPE_CHECKING: @@ -146,23 +146,10 @@ class AbstractAgent(ABC): class AbstractScriptedAgent(AbstractAgent): 
"""Base class for actors which generate their own behaviour.""" - pass - - -class RandomAgent(AbstractScriptedAgent): - """Agent that ignores its observation and acts completely at random.""" - + @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: - """Sample the action space randomly. - - :param obs: Current observation for this agent, not used in RandomAgent - :type obs: ObsType - :param timestep: The current simulation timestep, not used in RandomAgent - :type timestep: int - :return: Action formatted in CAOS format - :rtype: Tuple[str, Dict] - """ - return self.action_manager.get_action(self.action_manager.space.sample()) + """Return an action to be taken in the environment.""" + return super().get_action(obs=obs, timestep=timestep) class ProxyAgent(AbstractAgent): diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py deleted file mode 100644 index 83d1c4be..00000000 --- a/src/primaite/game/agent/observations.py +++ /dev/null @@ -1,1066 +0,0 @@ -"""Manages the observation space for the agent.""" -from abc import ABC, abstractmethod -from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING - -from gymnasium import spaces -from gymnasium.core import ObsType - -from primaite import getLogger -from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.nmne import CAPTURE_NMNE - -_LOGGER = getLogger(__name__) - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - - -class AbstractObservation(ABC): - """Abstract class for an observation space component.""" - - @abstractmethod - def observe(self, state: Dict) -> Any: - """ - Return an observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Any - """ - pass - - @property - @abstractmethod - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space.""" - pass - - @classmethod - @abstractmethod - def from_config(cls, config: Dict, game: "PrimaiteGame"): - """Create this observation space component form a serialised format. - - The `game` parameter is for a the PrimaiteGame object that spawns this component. - """ - pass - - -class FileObservation(AbstractObservation): - """Observation of a file on a node in the network.""" - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """ - Initialise file observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',,'files',] - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.default_observation: spaces.Space = {"health_status": 0} - "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - file_state = access_from_nested_dict(state, self.where) - if file_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return {"health_status": file_state["visible_status"]} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. 
- - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"health_status": spaces.Discrete(6)}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": - """Create file observation from a config. - - :param config: Dictionary containing the configuration for this file observation. - :type config: Dict - :param game: _description_ - :type game: PrimaiteGame - :param parent_where: _description_, defaults to None - :type parent_where: _type_, optional - :return: _description_ - :rtype: _type_ - """ - return cls(where=parent_where + ["files", config["file_name"]]) - - -class ServiceObservation(AbstractObservation): - """Observation of a service in the network.""" - - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} - "Default observation is what should be returned when the service doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise service observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'services', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - service_state = access_from_nested_dict(state, self.where) - if service_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], - } - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ServiceObservation": - """Create service observation from a config. - - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation - :rtype: ServiceObservation - """ - return cls(where=parent_where + ["services", config["service_name"]]) - - -class LinkObservation(AbstractObservation): - """Observation of a link in the network.""" - - default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} - "Default observation is what should be returned when the link doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise link observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. 
- - A typical location for a service looks like this: - `['network','nodes',,'servics', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - link_state = access_from_nested_dict(state, self.where) - if link_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - bandwidth = link_state["bandwidth"] - load = link_state["current_load"] - if load == 0: - utilisation_category = 0 - else: - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 9) + 1 - - # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": - """Create link observation from a config. - - :param config: Dictionary containing the configuration for this link observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. 
- :type game: PrimaiteGame - :return: Constructed link observation - :rtype: LinkObservation - """ - return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) - - -class FolderObservation(AbstractObservation): - """Folder observation, including files inside of the folder.""" - - def __init__( - self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 - ) -> None: - """Initialise folder Observation, including files inside of the folder. - - :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',] - :type where: Optional[List[str]] - :param max_files: As size of the space must remain static, define max files that can be in this folder - , defaults to 5 - :type max_files: int, optional - :param file_positions: Defines the positioning within the observation space of particular files. This ensures - that even if new files are created, the existing files will always occupy the same space in the observation - space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the - observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same - name, it will take the position defined in this dict. Defaults to {} - :type file_positions: Dict[int, str], optional - """ - super().__init__() - - self.where: Optional[Tuple[str]] = where - - self.files: List[FileObservation] = files - while len(self.files) < num_files_per_folder: - self.files.append(FileObservation()) - while len(self.files) > num_files_per_folder: - truncated_file = self.files.pop() - msg = f"Too many files in folder observation. 
Truncating file {truncated_file}" - _LOGGER.warning(msg) - - self.default_observation = { - "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - folder_state = access_from_nested_dict(state, self.where) - if folder_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - health_status = folder_state["health_status"] - - obs = {} - - obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 - ) -> "FolderObservation": - """Create folder observation from a config. Also creates child file observations. - - :param config: Dictionary containing the configuration for this folder observation. Includes the name of the - folder and the files inside of it. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. 
A typical location for a node ``where`` can be: - ['network','nodes',,'file_system'] - :type parent_where: Optional[List[str]] - :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed folder observation - :rtype: FolderObservation - """ - where = parent_where + ["folders", config["folder_name"]] - - file_configs = config["files"] - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] - - return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) - - -class NicObservation(AbstractObservation): - """Observation of a Network Interface Card (NIC) in the network.""" - - @property - def default_observation(self) -> Dict: - """The default NIC observation dict.""" - data = {"nic_status": 0} - if CAPTURE_NMNE: - data.update({"nmne": {"inbound": 0, "outbound": 0}}) - - return data - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise NIC observation. - - :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical - example may look like this: - ['network','nodes',,'NICs',] - If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. - :type where: Optional[Tuple[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - if CAPTURE_NMNE: - self.nmne_inbound_last_step: int = 0 - """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets - us find the difference.""" - self.nmne_outbound_last_step: int = 0 - """NMNEs persist for the whole episode, but we want to count per step. 
Keeping track of last step count lets - us find the difference.""" - - def _categorise_mne_count(self, nmne_count: int) -> int: - """ - Categorise the number of Malicious Network Events (NMNEs) into discrete bins. - - This helps in classifying the severity or volume of MNEs into manageable levels for the agent. - - Bins are defined as follows: - - 0: No MNEs detected (0 events). - - 1: Low number of MNEs (1-5 events). - - 2: Moderate number of MNEs (6-10 events). - - 3: High number of MNEs (more than 10 events). - - :param nmne_count: Number of MNEs detected. - :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. - """ - if nmne_count > 10: - return 3 - elif nmne_count > 5: - return 2 - elif nmne_count > 0: - return 1 - return 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - nic_state = access_from_nested_dict(state, self.where) - - if nic_state is NOT_PRESENT_IN_STATE: - return self.default_observation - else: - obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} - if CAPTURE_NMNE: - obs_dict.update({"nmne": {}}) - direction_dict = nic_state["nmne"].get("direction", {}) - inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) - inbound_count = inbound_keywords.get("*", 0) - outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) - outbound_count = outbound_keywords.get("*", 0) - obs_dict["nmne"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) - obs_dict["nmne"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) - self.nmne_inbound_last_step = inbound_count - self.nmne_outbound_last_step = outbound_count - return obs_dict - - @property - def space(self) 
-> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "nic_status": spaces.Discrete(3), - "nmne": spaces.Dict({"inbound": spaces.Discrete(6), "outbound": spaces.Discrete(6)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": - """Create NIC observation from a config. - - :param config: Dictionary containing the configuration for this NIC observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent - node. A typical location for a node ``where`` can be: ['network','nodes',] - :type parent_where: Optional[List[str]] - :return: Constructed NIC observation - :rtype: NicObservation - """ - return cls(where=parent_where + ["NICs", config["nic_num"]]) - - -class NodeObservation(AbstractObservation): - """Observation of a node in the network. Includes services, folders and NICs.""" - - def __init__( - self, - where: Optional[Tuple[str]] = None, - services: List[ServiceObservation] = [], - folders: List[FolderObservation] = [], - network_interfaces: List[NicObservation] = [], - logon_status: bool = False, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> None: - """ - Configurable observation for a node in the simulation. - - :param where: Where in the simulation state dictionary for find relevant information for this observation. - A typical location for a node looks like this: - ['network','nodes',]. 
If empty list, a default null observation will be output, defaults to [] - :type where: List[str], optional - :param services: Mapping between position in observation space and service name, defaults to {} - :type services: Dict[int,str], optional - :param max_services: Max number of services that can be presented in observation space for this node - , defaults to 2 - :type max_services: int, optional - :param folders: Mapping between position in observation space and folder name, defaults to {} - :type folders: Dict[int,str], optional - :param max_folders: Max number of folders in this node's obs space, defaults to 2 - :type max_folders: int, optional - :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} - :type network_interfaces: Dict[int,str], optional - :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 - :type max_nics: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.services: List[ServiceObservation] = services - while len(self.services) < num_services_per_node: - # add empty service observation without `where` parameter so it always returns default (blank) observation - self.services.append(ServiceObservation()) - while len(self.services) > num_services_per_node: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - # truncate service list - - self.folders: List[FolderObservation] = folders - # add empty folder observation without `where` parameter that will always return default (blank) observations - while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) - while len(self.folders) > num_folders_per_node: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation for node. 
Truncating service {truncated_folder.where[-1]}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NicObservation] = network_interfaces - while len(self.network_interfaces) < num_nics_per_node: - self.network_interfaces.append(NicObservation()) - while len(self.network_interfaces) > num_nics_per_node: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warning(msg) - - self.logon_status: bool = logon_status - - self.default_observation: Dict = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, - } - if self.logon_status: - self.default_observation["logon_status"] = 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NETWORK_INTERFACES"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - - if self.logon_status: - obs["logon_status"] = 0 - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - space_shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NETWORK_INTERFACES": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - } - if self.logon_status: - space_shape["logon_status"] = spaces.Discrete(3) - - return spaces.Dict(space_shape) - - @classmethod - def from_config( - cls, - config: Dict, - game: "PrimaiteGame", - parent_where: Optional[List[str]] = None, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> "NodeObservation": - """Create node observation from a config. Also creates child service, folder and NIC observations. - - :param config: Dictionary containing the configuration for this node observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. 
- :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this node's parent - network. A typical location for it would be: ['network',] - :type parent_where: Optional[List[str]] - :param num_services_per_node: How many spaces for services are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_services_per_node: int, optional - :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_folders_per_node: int, optional - :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed node observation - :rtype: NodeObservation - """ - node_hostname = config["node_hostname"] - if parent_where is None: - where = ["network", "nodes", node_hostname] - else: - where = parent_where + ["nodes", node_hostname] - - svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] - folder_configs = config.get("folders", {}) - folders = [ - FolderObservation.from_config( - config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder - ) - for c in folder_configs - ] - # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. 
- nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] - network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] - logon_status = config.get("logon_status", False) - return cls( - where=where, - services=services, - folders=folders, - network_interfaces=network_interfaces, - logon_status=logon_status, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - - -class AclObservation(AbstractObservation): - """Observation of an Access Control List (ACL) in the network.""" - - # TODO: should where be optional, and we can use where=None to pad the observation space? - # definitely the current approach does not support tracking files that aren't specified by name, for example - # if a file is created at runtime, we have currently got no way of telling the observation space to track it. - # this needs adding, but not for the MVP. - def __init__( - self, - node_ip_to_id: Dict[str, int], - ports: List[int], - protocols: List[str], - where: Optional[Tuple[str]] = None, - num_rules: int = 10, - ) -> None: - """Initialise ACL observation. - - :param node_ip_to_id: Mapping between IP address and ID. - :type node_ip_to_id: Dict[str, int] - :param ports: List of ports which are part of the game that define the ordering when converting to an ID - :type ports: List[int] - :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID - :type protocols: list[str] - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical - example may look like this: - ['network','nodes',,'acl','acl'] - :type where: Optional[Tuple[str]], optional - :param num_rules: , defaults to 10 - :type num_rules: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.num_rules: int = num_rules - self.node_to_id: Dict[str, int] = node_ip_to_id - "List of node IP addresses, order in this list determines how they are converted to an ID" - self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} - "List of ports which are part of the game that define the ordering when converting to an ID" - self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} - "List of protocols which are part of the game, defines ordering when converting to an ID" - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - for i in range(self.num_rules) - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - # TODO: what if the ACL has more rules than num of max rules for obs space - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. 
- while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] - protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, - "protocol": protocol_id, - } - i += 1 - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": - """Generate ACL observation from a config. 
- - :param config: Dictionary containing the configuration for this ACL observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Observation object - :rtype: AclObservation - """ - max_acl_rules = config["options"]["max_acl_rules"] - node_ip_to_idx = {} - for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): - node_ref = ip_map_config["node_hostname"] - nic_num = ip_map_config["nic_num"] - node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] - nic_obj = node_obj.network_interface[nic_num] - node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - - router_hostname = config["router_hostname"] - return cls( - node_ip_to_id=node_ip_to_idx, - ports=game.options.ports, - protocols=game.options.protocols, - where=["network", "nodes", router_hostname, "acl", "acl"], - num_rules=max_acl_rules, - ) - - -class NullObservation(AbstractObservation): - """Null observation, returns a single 0 value for the observation space.""" - - def __init__(self, where: Optional[List[str]] = None): - """Initialise null observation.""" - self.default_observation: Dict = {} - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - return 0 - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Discrete(1) - - @classmethod - def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": - """ - Create null observation from a config. - - The parameters are ignored, they are here to match the signature of the other observation classes. 
- """ - return cls() - - -class ICSObservation(NullObservation): - """ICS observation placeholder, currently not implemented so always returns a single 0.""" - - pass - - -class UC2BlueObservation(AbstractObservation): - """Container for all observations used by the blue agent in UC2. - - TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler - for the purpose of compiling several observation components. - """ - - def __init__( - self, - nodes: List[NodeObservation], - links: List[LinkObservation], - acl: AclObservation, - ics: ICSObservation, - where: Optional[List[str]] = None, - ) -> None: - """Initialise UC2 blue observation. - - :param nodes: List of node observations - :type nodes: List[NodeObservation] - :param links: List of link observations - :type links: List[LinkObservation] - :param acl: The Access Control List observation - :type acl: AclObservation - :param ics: The ICS observation - :type ics: ICSObservation - :param where: Where in the simulation state dict to find information. Not used in this particular observation - because it only compiles other observations and doesn't contribute any new information, defaults to None - :type where: Optional[List[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.nodes: List[NodeObservation] = nodes - self.links: List[LinkObservation] = links - self.acl: AclObservation = acl - self.ics: ICSObservation = ics - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, - "ACL": self.acl.default_observation, - "ICS": self.ics.default_observation, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs["ACL"] = self.acl.observe(state) - obs["ICS"] = self.ics.observe(state) - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), - "ACL": self.acl.space, - "ICS": self.ics.space, - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": - """Create UC2 blue observation from a config. - - :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, - links, ACL and ICS observations. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. 
- :type game: PrimaiteGame - :return: Constructed UC2 blue observation - :rtype: UC2BlueObservation - """ - node_configs = config["nodes"] - - num_services_per_node = config["num_services_per_node"] - num_folders_per_node = config["num_folders_per_node"] - num_files_per_folder = config["num_files_per_folder"] - num_nics_per_node = config["num_nics_per_node"] - nodes = [ - NodeObservation.from_config( - config=n, - game=game, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - for n in node_configs - ] - - link_configs = config["links"] - links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] - - acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, game=game) - - ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, game=game) - new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) - return new - - -class UC2RedObservation(AbstractObservation): - """Container for all observations used by the red agent in UC2.""" - - def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: - super().__init__() - self.where: Optional[List[str]] = where - self.nodes: List[NodeObservation] = nodes - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), 
- } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": - """ - Create UC2 red observation from a config. - - :param config: Dictionary containing the configuration for this UC2 red observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] - return cls(nodes=nodes, where=["network"]) - - -class UC2GreenObservation(NullObservation): - """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" - - pass - - -class ObservationManager: - """ - Manage the observations of an Agent. - - The observation space has the purpose of: - 1. Reading the outputted state from the PrimAITE Simulation. - 2. Selecting parts of the simulation state that are requested by the simulation config - 3. Formatting this information so an agent can use it to make decisions. - """ - - # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed - # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next - # refactor. - - def __init__(self, observation: AbstractObservation) -> None: - """Initialise observation space. - - :param observation: Observation object - :type observation: AbstractObservation - """ - self.obs: AbstractObservation = observation - self.current_observation: ObsType - - def update(self, state: Dict) -> Dict: - """ - Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - """ - self.current_observation = self.obs.observe(state) - return self.current_observation - - @property - def space(self) -> None: - """Gymnasium space object describing the observation space shape.""" - return self.obs.space - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": - """Create observation space from a config. - - :param config: Dictionary containing the configuration for this observation space. - It should contain the key 'type' which selects which observation class to use (from a choice of: - UC2BlueObservation, UC2RedObservation, UC2GreenObservation) - The other key is 'options' which are passed to the constructor of the selected observation class. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) - else: - raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py new file mode 100644 index 00000000..70a83881 --- /dev/null +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -0,0 +1,188 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.node_observations import NodeObservation +from 
primaite.game.agent.observations.observations import ( + AbstractObservation, + AclObservation, + ICSObservation, + LinkObservation, + NullObservation, +) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class UC2BlueObservation(AbstractObservation): + """Container for all observations used by the blue agent in UC2. + + TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler + for the purpose of compiling several observation components. + """ + + def __init__( + self, + nodes: List[NodeObservation], + links: List[LinkObservation], + acl: AclObservation, + ics: ICSObservation, + where: Optional[List[str]] = None, + ) -> None: + """Initialise UC2 blue observation. + + :param nodes: List of node observations + :type nodes: List[NodeObservation] + :param links: List of link observations + :type links: List[LinkObservation] + :param acl: The Access Control List observation + :type acl: AclObservation + :param ics: The ICS observation + :type ics: ICSObservation + :param where: Where in the simulation state dict to find information. Not used in this particular observation + because it only compiles other observations and doesn't contribute any new information, defaults to None + :type where: Optional[List[str]], optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + self.nodes: List[NodeObservation] = nodes + self.links: List[LinkObservation] = links + self.acl: AclObservation = acl + self.ics: ICSObservation = ics + + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, + "ACL": self.acl.default_observation, + "ICS": self.ics.default_observation, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. 
+ + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + obs = {} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} + obs["ACL"] = self.acl.observe(state) + obs["ICS"] = self.ics.observe(state) + + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), + "ACL": self.acl.space, + "ICS": self.ics.space, + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": + """Create UC2 blue observation from a config. + + :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, + links, ACL and ICS observations. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. 
+ :type game: PrimaiteGame + :return: Constructed UC2 blue observation + :rtype: UC2BlueObservation + """ + node_configs = config["nodes"] + + num_services_per_node = config["num_services_per_node"] + num_folders_per_node = config["num_folders_per_node"] + num_files_per_folder = config["num_files_per_folder"] + num_nics_per_node = config["num_nics_per_node"] + nodes = [ + NodeObservation.from_config( + config=n, + game=game, + num_services_per_node=num_services_per_node, + num_folders_per_node=num_folders_per_node, + num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, + ) + for n in node_configs + ] + + link_configs = config["links"] + links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] + + acl_config = config["acl"] + acl = AclObservation.from_config(config=acl_config, game=game) + + ics_config = config["ics"] + ics = ICSObservation.from_config(config=ics_config, game=game) + new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) + return new + + +class UC2RedObservation(AbstractObservation): + """Container for all observations used by the red agent in UC2.""" + + def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: + super().__init__() + self.where: Optional[List[str]] = where + self.nodes: List[NodeObservation] = nodes + + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation.""" + if self.where is None: + return self.default_observation + + obs = {} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), 
+ } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": + """ + Create UC2 red observation from a config. + + :param config: Dictionary containing the configuration for this UC2 red observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + """ + node_configs = config["nodes"] + nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] + return cls(nodes=nodes, where=["network"]) + + +class UC2GreenObservation(NullObservation): + """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" + + pass diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py new file mode 100644 index 00000000..277bc51f --- /dev/null +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class FileObservation(AbstractObservation): + """Observation of a file on a node in the network.""" + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """ + Initialise file observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. 
+ + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',,'files',] + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.default_observation: spaces.Space = {"health_status": 0} + "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + file_state = access_from_nested_dict(state, self.where) + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"health_status": file_state["visible_status"]} + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"health_status": spaces.Discrete(6)}) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": + """Create file observation from a config. + + :param config: Dictionary containing the configuration for this file observation. + :type config: Dict + :param game: _description_ + :type game: PrimaiteGame + :param parent_where: _description_, defaults to None + :type parent_where: _type_, optional + :return: _description_ + :rtype: _type_ + """ + return cls(where=parent_where + ["files", config["file_name"]]) + + +class FolderObservation(AbstractObservation): + """Folder observation, including files inside of the folder.""" + + def __init__( + self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 + ) -> None: + """Initialise folder Observation, including files inside the folder. 
+
+        :param where: Where in the simulation state dictionary to find the relevant information for this folder.
+            A typical location for a folder looks like this:
+            ['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
+            If None, the default observation is always returned.
+        :type where: Optional[Tuple[str]], optional
+        :param files: Observations for the files inside this folder, in the order they take in the observation
+            space. The list is copied, so neither the caller's list nor the shared default is modified.
+        :type files: List[FileObservation], optional
+        :param num_files_per_folder: Fixed number of file slots in this folder's observation space. Missing slots are
+            padded with blank file observations and surplus entries are dropped with a warning, defaults to 2
+        :type num_files_per_folder: int, optional
+        """
+        super().__init__()
+
+        self.where: Optional[Tuple[str]] = where
+
+        # Copy before padding/truncating: ``files`` defaults to a shared ``[]`` and may also be a list owned by the
+        # caller, so mutating it in place would leak padding entries into other FolderObservation instances.
+        self.files: List[FileObservation] = list(files)
+        while len(self.files) < num_files_per_folder:
+            self.files.append(FileObservation())
+        while len(self.files) > num_files_per_folder:
+            truncated_file = self.files.pop()
+            msg = f"Too many files in folder observation. Truncating file {truncated_file}"
+            _LOGGER.warning(msg)
+
+        self.default_observation = {
+            "health_status": 0,
+            "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
+        }
+
+    def observe(self, state: Dict) -> Dict:
+        """Generate observation based on the current state of the simulation.
+ + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + ) -> "FolderObservation": + """Create folder observation from a config. Also creates child file observations. + + :param config: Dictionary containing the configuration for this folder observation. Includes the name of the + folder and the files inside of it. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this folder's + parent node. 
A typical location for a node ``where`` can be: + ['network','nodes',,'file_system'] + :type parent_where: Optional[List[str]] + :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed folder observation + :rtype: FolderObservation + """ + where = parent_where + ["folders", config["folder_name"]] + + file_configs = config["files"] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] + + return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py new file mode 100644 index 00000000..de83e03a --- /dev/null +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -0,0 +1,188 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import CAPTURE_NMNE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NicObservation(AbstractObservation): + """Observation of a Network Interface Card (NIC) in the network.""" + + low_nmne_threshold: int = 0 + """The minimum number of malicious network events to be considered low.""" + med_nmne_threshold: int = 5 + """The minimum number of malicious network events to be considered medium.""" + high_nmne_threshold: int = 10 + """The minimum number of malicious network events to be considered high.""" + + global CAPTURE_NMNE + + @property + def default_observation(self) -> Dict: + """The default NIC observation dict.""" + data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"NMNE": {"inbound": 0, "outbound": 0}}) + 
+        return data
+
+    def __init__(
+        self,
+        where: Optional[Tuple[str]] = None,
+        low_nmne_threshold: Optional[int] = 0,
+        med_nmne_threshold: Optional[int] = 5,
+        high_nmne_threshold: Optional[int] = 10,
+    ) -> None:
+        """Initialise NIC observation.
+
+        :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
+            example may look like this:
+            ['network','nodes',<node_hostname>,'NICs',<nic_num>]
+            If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
+        :type where: Optional[Tuple[str]], optional
+        :param low_nmne_threshold: Per-step NMNE count above which the "low" category is reported. ``None`` falls
+            back to the class default.
+        :type low_nmne_threshold: Optional[int], optional
+        :param med_nmne_threshold: Per-step NMNE count above which the "medium" category is reported. ``None`` falls
+            back to the class default.
+        :type med_nmne_threshold: Optional[int], optional
+        :param high_nmne_threshold: Per-step NMNE count above which the "high" category is reported. ``None`` falls
+            back to the class default.
+        :type high_nmne_threshold: Optional[int], optional
+        :raises Exception: If the resulting thresholds are not strictly increasing (low < medium < high).
+        """
+        super().__init__()
+        self.where: Optional[Tuple[str]] = where
+
+        global CAPTURE_NMNE
+        if CAPTURE_NMNE:
+            self.nmne_inbound_last_step: int = 0
+            """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
+            us find the difference."""
+            self.nmne_outbound_last_step: int = 0
+            """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
+            us find the difference."""
+
+        # ``from_config`` passes None for every threshold missing from the game config. Substitute the class default
+        # for each None so a partially specified config is validated with real integers instead of raising a
+        # TypeError inside the comparison.
+        low = self.low_nmne_threshold if low_nmne_threshold is None else low_nmne_threshold
+        med = self.med_nmne_threshold if med_nmne_threshold is None else med_nmne_threshold
+        high = self.high_nmne_threshold if high_nmne_threshold is None else high_nmne_threshold
+        self._validate_nmne_categories(low_nmne_threshold=low, med_nmne_threshold=med, high_nmne_threshold=high)
+
+    def _validate_nmne_categories(
+        self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
+    ):
+        """
+        Validates the nmne threshold config.
+
+        If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
+ + :param: low_nmne_threshold: The minimum number of malicious network events to be considered low + :param: med_nmne_threshold: The minimum number of malicious network events to be considered medium + :param: high_nmne_threshold: The minimum number of malicious network events to be considered high + """ + if high_nmne_threshold <= med_nmne_threshold: + raise Exception( + f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater " + f"than medium nmne count ({med_nmne_threshold})" + ) + + if med_nmne_threshold <= low_nmne_threshold: + raise Exception( + f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater " + f"than low nmne count ({low_nmne_threshold})" + ) + + self.high_nmne_threshold = high_nmne_threshold + self.med_nmne_threshold = med_nmne_threshold + self.low_nmne_threshold = low_nmne_threshold + + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (default 1-5 events). + - 2: Moderate number of MNEs (default 6-10 events). + - 3: High number of MNEs (default more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. + """ + if nmne_count > self.high_nmne_threshold: + return 3 + elif nmne_count > self.med_nmne_threshold: + return 2 + elif nmne_count > self.low_nmne_threshold: + return 1 + return 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. 
+ + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + nic_state = access_from_nested_dict(state, self.where) + + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + else: + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"NMNE": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count + return obs_dict + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + space = spaces.Dict({"nic_status": spaces.Discrete(3)}) + + if CAPTURE_NMNE: + space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + + return space + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": + """Create NIC observation from a config. + + :param config: Dictionary containing the configuration for this NIC observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent + node. 
A typical location for a node ``where`` can be: ['network','nodes',] + :type parent_where: Optional[List[str]] + :return: Constructed NIC observation + :rtype: NicObservation + """ + low_nmne_threshold = None + med_nmne_threshold = None + high_nmne_threshold = None + + if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"): + threshold = game.options.thresholds["nmne"] + + low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None + med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None + high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None + + return cls( + where=parent_where + ["NICs", config["nic_num"]], + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py new file mode 100644 index 00000000..94f0974b --- /dev/null +++ b/src/primaite/game/agent/observations/node_observations.py @@ -0,0 +1,200 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.file_system_observations import FolderObservation +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.software_observation import ServiceObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NodeObservation(AbstractObservation): + """Observation of a node in the network. 
Includes services, folders and NICs."""
+
+    def __init__(
+        self,
+        where: Optional[Tuple[str]] = None,
+        services: List[ServiceObservation] = [],
+        folders: List[FolderObservation] = [],
+        network_interfaces: List[NicObservation] = [],
+        logon_status: bool = False,
+        num_services_per_node: int = 2,
+        num_folders_per_node: int = 2,
+        num_files_per_folder: int = 2,
+        num_nics_per_node: int = 2,
+    ) -> None:
+        """
+        Configurable observation for a node in the simulation.
+
+        The ``services``, ``folders`` and ``network_interfaces`` lists are copied before being padded or truncated, so
+        neither the caller's lists nor the shared default lists are modified.
+
+        :param where: Where in the simulation state dictionary to find relevant information for this observation.
+            A typical location for a node looks like this:
+            ['network','nodes',<node_hostname>]. If None, the default (blank) observation is always returned,
+            defaults to None
+        :type where: Optional[Tuple[str]], optional
+        :param services: Service observations for this node, in observation-space order, defaults to []
+        :type services: List[ServiceObservation], optional
+        :param folders: Folder observations for this node, in observation-space order, defaults to []
+        :type folders: List[FolderObservation], optional
+        :param network_interfaces: NIC observations for this node, in observation-space order, defaults to []
+        :type network_interfaces: List[NicObservation], optional
+        :param logon_status: Whether to include a ``logon_status`` entry in the observation, defaults to False
+        :type logon_status: bool, optional
+        :param num_services_per_node: Fixed number of service slots; the service list is padded with blank
+            observations or truncated (with a warning) to this length, defaults to 2
+        :type num_services_per_node: int, optional
+        :param num_folders_per_node: Fixed number of folder slots; the folder list is padded with blank observations
+            or truncated (with a warning) to this length, defaults to 2
+        :type num_folders_per_node: int, optional
+        :param num_files_per_folder: Number of file slots given to each blank folder observation added as padding,
+            defaults to 2
+        :type num_files_per_folder: int, optional
+        :param num_nics_per_node: Fixed number of NIC slots; the NIC list is padded with blank observations or
+            truncated (with a warning) to this length, defaults to 2
+        :type num_nics_per_node: int, optional
+        """
+        super().__init__()
+        self.where: Optional[Tuple[str]] = where
+
+        # Copy each incoming list: the ``[]`` defaults are shared between calls, and padding/truncating them in place
+        # would leak entries from one NodeObservation into the next (and silently edit lists owned by the caller).
+        self.services: List[ServiceObservation] = list(services)
+        while len(self.services) < num_services_per_node:
+            # add empty service observation without `where` parameter so it always returns default (blank) observation
+            self.services.append(ServiceObservation())
+        while len(self.services) > num_services_per_node:
+            truncated_service = self.services.pop()
+            msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
+            _LOGGER.warning(msg)
+            # truncate service list
+
+        self.folders: List[FolderObservation] = list(folders)
+        # add empty folder observation without `where` parameter that will always return default (blank) observations
+        while len(self.folders) < num_folders_per_node:
+            self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
+        while len(self.folders) > num_folders_per_node:
+            truncated_folder = self.folders.pop()
+            msg = f"Too many folders in Node observation for node. Truncating folder {truncated_folder.where[-1]}"
+            _LOGGER.warning(msg)
+
+        self.network_interfaces: List[NicObservation] = list(network_interfaces)
+        while len(self.network_interfaces) < num_nics_per_node:
+            self.network_interfaces.append(NicObservation())
+        while len(self.network_interfaces) > num_nics_per_node:
+            truncated_nic = self.network_interfaces.pop()
+            msg = f"Too many NICs in Node observation for node. Truncating NIC {truncated_nic.where[-1]}"
+            _LOGGER.warning(msg)
+
+        self.logon_status: bool = logon_status
+
+        self.default_observation: Dict = {
+            "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
+            "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
+            "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
+            "operating_status": 0,
+        }
+        if self.logon_status:
+            self.default_observation["logon_status"] = 0
+
+    def observe(self, state: Dict) -> Dict:
+        """Generate observation based on the current state of the simulation.
@classmethod
def from_config(
    cls,
    config: Dict,
    game: "PrimaiteGame",
    parent_where: Optional[List[str]] = None,
    num_services_per_node: int = 2,
    num_folders_per_node: int = 2,
    num_files_per_folder: int = 2,
    num_nics_per_node: int = 2,
) -> "NodeObservation":
    """Create node observation from a config. Also creates child service, folder and NIC observations.

    :param config: Dictionary containing the configuration for this node observation. Must contain
        ``node_hostname``; may contain ``services``, ``folders`` and ``logon_status``.
    :type config: Dict
    :param game: Reference to the PrimaiteGame object that spawned this observation.
    :type game: PrimaiteGame
    :param parent_where: Where in the simulation state dictionary to find the information about this node's parent
        network. A typical location for it would be: ['network',]. If None, ['network'] is assumed.
    :type parent_where: Optional[List[str]]
    :param num_services_per_node: How many spaces for services are in this node observation (to preserve static
        observation size) , defaults to 2
    :type num_services_per_node: int, optional
    :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
        observation size) , defaults to 2
    :type num_folders_per_node: int, optional
    :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
        observation size) , defaults to 2
    :type num_files_per_folder: int, optional
    :param num_nics_per_node: How many spaces for NICs are in this node observation (to preserve static
        observation size) , defaults to 2
    :type num_nics_per_node: int, optional
    :return: Constructed node observation
    :rtype: NodeObservation
    """
    node_hostname = config["node_hostname"]
    if parent_where is None:
        where = ["network", "nodes", node_hostname]
    else:
        where = parent_where + ["nodes", node_hostname]

    svc_configs = config.get("services", [])
    services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
    folder_configs = config.get("folders", [])
    folders = [
        FolderObservation.from_config(
            config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
        )
        for c in folder_configs
    ]
    # create one config per NIC slot in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
    # (the comprehension must wrap the whole dict; a dict comprehension inside a list yields a single config)
    nic_configs = [{"nic_num": i + 1} for i in range(num_nics_per_node)]
    network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
    logon_status = config.get("logon_status", False)
    return cls(
        where=where,
        services=services,
        folders=folders,
        network_interfaces=network_interfaces,
        logon_status=logon_status,
        num_services_per_node=num_services_per_node,
        num_folders_per_node=num_folders_per_node,
        num_files_per_folder=num_files_per_folder,
        num_nics_per_node=num_nics_per_node,
    )
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
    """Build an observation manager from its configuration dictionary.

    :param config: Configuration for this observation space. The 'type' key names the observation class to
        wrap (one of: UC2BlueObservation, UC2RedObservation, UC2GreenObservation). The optional 'options' key
        is handed to that class's ``from_config``; an empty dict is used when it is absent.
    :type config: Dict
    :param game: Reference to the PrimaiteGame object that spawned this observation.
    :type game: PrimaiteGame
    :return: Manager wrapping the constructed observation
    :rtype: ObservationManager
    :raises ValueError: If 'type' does not name one of the supported observation classes.
    """
    requested_type = config["type"]
    supported = (
        ("UC2BlueObservation", UC2BlueObservation),
        ("UC2RedObservation", UC2RedObservation),
        ("UC2GreenObservation", UC2GreenObservation),
    )
    for type_name, observation_cls in supported:
        if requested_type == type_name:
            options = config.get("options", {})
            return cls(observation_cls.from_config(options, game=game))
    raise ValueError("Observation space type invalid")
class LinkObservation(AbstractObservation):
    """Observation of a link in the network."""

    default_observation: Dict = {"PROTOCOLS": {"ALL": 0}}
    "Default observation is what should be returned when the link doesn't exist."

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """Initialise link observation.

        :param where: Store information about where in the simulation state dictionary to find the relevant information.
            Optional. If None, the link is treated as not existing and the observation will be populated with
            zeroes.

            A typical location for a link looks like this:
            `['network','links',<link_ref>]`
        :type where: Optional[Tuple[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation holding the link utilisation category (0-10) under PROTOCOLS/ALL
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation

        link_state = access_from_nested_dict(state, self.where)
        if link_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        bandwidth = link_state["bandwidth"]
        load = link_state["current_load"]
        if load == 0:
            utilisation_category = 0
        else:
            utilisation_fraction = load / bandwidth
            # 0 is UNUSED. Categories 1-9 split utilisation below 100% into nine equal bands (each one ninth
            # of the bandwidth wide); 10 is reached only at 100% utilisation or above.
            utilisation_category = int(utilisation_fraction * 9) + 1

        # TODO: once the links support separate load per protocol, this needs amendment to reflect that.
        return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})

    @classmethod
    def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
        """Create link observation from a config.

        :param config: Dictionary containing the configuration for this link observation. Must contain
            ``link_ref``, which is resolved through ``game.ref_map_links``.
        :type config: Dict
        :param game: Reference to the PrimaiteGame object that spawned this observation.
        :type game: PrimaiteGame
        :return: Constructed link observation
        :rtype: LinkObservation
        """
        return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
def observe(self, state: Dict) -> Dict:
    """Build the ACL observation from the current simulation state.

    Rule 0 is never shown (for compatibility reasons); rules 1..num_rules are reported, each keyed by its
    rule number with a zero-based ``position`` field. In the ID fields 0 means unused and 1 means "any".

    :param state: Simulation state dictionary
    :type state: Dict
    :return: Observation, or the default observation when the ACL cannot be located in the state
    :rtype: Dict
    """
    if self.where is None:
        return self.default_observation
    acl_state: Dict = access_from_nested_dict(state, self.where)
    if acl_state is NOT_PRESENT_IN_STATE:
        return self.default_observation

    # TODO: what if the ACL has more rules than num of max rules for obs space
    rules_by_number = {number: rule for number, rule in acl_state.items()}
    observation = {}
    for rule_number in range(1, self.num_rules + 1):
        rule = rules_by_number[rule_number]
        if rule is None:
            # empty slot in the ACL: report a blank rule at this position
            observation[rule_number] = {
                "position": rule_number - 1,
                "permission": 0,
                "source_node_id": 0,
                "source_port": 0,
                "dest_node_id": 0,
                "dest_port": 0,
                "protocol": 0,
            }
            continue

        # a None field in the rule means "any", which is encoded as 1
        source_ip = rule["src_ip_address"]
        source_node_id = self.node_to_id[IPv4Address(source_ip)] if source_ip is not None else 1
        dest_ip = rule["dst_ip_address"]
        dest_node_id = self.node_to_id[IPv4Address(dest_ip)] if dest_ip is not None else 1
        source_port = rule["src_port"]
        source_port_id = self.port_to_id[source_port] if source_port is not None else 1
        dest_port = rule["dst_port"]
        dest_port_id = self.port_to_id[dest_port] if dest_port is not None else 1
        protocol = rule["protocol"]
        protocol_id = self.protocol_to_id[protocol] if protocol is not None else 1
        observation[rule_number] = {
            "position": rule_number - 1,
            "permission": rule["action"],
            "source_node_id": source_node_id,
            "source_port": source_port_id,
            "dest_node_id": dest_node_id,
            "dest_port": dest_port_id,
            "protocol": protocol_id,
        }
    return observation
class NullObservation(AbstractObservation):
    """Null observation, returns a single 0 value for the observation space."""

    def __init__(self, where: Optional[List[str]] = None):
        """Initialise null observation.

        :param where: Accepted but not used by this class.
        :type where: Optional[List[str]], optional
        """
        self.default_observation: Dict = {}

    def observe(self, state: Dict) -> int:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary (ignored)
        :type state: Dict
        :return: Always 0, the only member of this observation's Discrete(1) space
        :rtype: int
        """
        return 0

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Discrete(1)

    @classmethod
    def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
        """
        Create null observation from a config.

        The parameters are ignored, they are here to match the signature of the other observation classes.
        """
        return cls()
class ApplicationObservation(AbstractObservation):
    """Observation of an application in the network."""

    default_observation: Dict = {"operating_status": 0, "health_status": 0, "num_executions": 0}
    "Default observation is what should be returned when the application doesn't exist."

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """Initialise application observation.

        :param where: Store information about where in the simulation state dictionary to find the relevant information.
            Optional. If None, the application is treated as not existing and the observation will be populated
            with zeroes.

            A typical location for an application looks like this:
            `['network','nodes',<node_hostname>,'applications', <application_name>]`
        :type where: Optional[Tuple[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation

        app_state = access_from_nested_dict(state, self.where)
        if app_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        return {
            "operating_status": app_state["operating_state"],
            "health_status": app_state["health_state_visible"],
            # the raw execution count is reduced to one of four categories (see _categorise_num_executions)
            "num_executions": self._categorise_num_executions(app_state["num_executions"]),
        }

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Dict(
            {
                "operating_status": spaces.Discrete(7),
                "health_status": spaces.Discrete(6),
                "num_executions": spaces.Discrete(4),
            }
        )

    @classmethod
    def from_config(
        cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
    ) -> "ApplicationObservation":
        """Create application observation from a config.

        :param config: Dictionary containing the configuration for this application observation. Must contain
            ``application_name``.
        :type config: Dict
        :param game: Reference to the PrimaiteGame object that spawned this observation.
        :type game: PrimaiteGame
        :param parent_where: Where in the simulation state dictionary this application's parent node is located.
            Although the default is None, a list must be supplied because the application path is appended to it.
        :type parent_where: Optional[List[str]], optional
        :return: Constructed application observation
        :rtype: ApplicationObservation
        """
        # applications are looked up under the node's 'applications' key, not under 'services'
        return cls(where=parent_where + ["applications", config["application_name"]])

    @classmethod
    def _categorise_num_executions(cls, num_executions: int) -> int:
        """
        Categorise the number of executions of an application.

        Helps classify the number of application executions into different categories.

        Current categories:
        - 0: Application is never executed
        - 1: Application is executed a low number of times (1-5)
        - 2: Application is executed often (6-10)
        - 3: Application is executed a high number of times (more than 10)

        :param num_executions: Number of times the application is executed
        :type num_executions: int
        :return: Category index between 0 and 3
        :rtype: int
        """
        if num_executions > 10:
            return 3
        elif num_executions > 5:
            return 2
        elif num_executions > 0:
            return 1
        return 0
class RandomAgent(AbstractScriptedAgent):
    """Scripted agent whose actions are drawn at random, regardless of what it observes."""

    def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
        """Pick a random action from the agent's action space.

        :param obs: Current observation for this agent, not used in RandomAgent
        :type obs: ObsType
        :param timestep: The current simulation timestep, not used in RandomAgent
        :type timestep: int
        :return: Action formatted in CAOS format
        :rtype: Tuple[str, Dict]
        """
        manager = self.action_manager
        sampled_action = manager.space.sample()
        return manager.get_action(sampled_action)
def apply_timestep(self, timestep: int) -> None:
    """
    Apply a timestep to the file.

    Delegates to the parent class first, then clears the per-step access counter so that ``num_access``
    only counts accesses made within a single step.

    :param timestep: The current timestep of the simulation.
    :type timestep: int
    """
    super().apply_timestep(timestep=timestep)

    # reset the number of accesses to 0
    self.num_access = 0
True self.sys_log.info(f"File deleted {self.folder_name}/{self.name}") return True diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ade03412..9166178c 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -28,6 +28,10 @@ class FileSystem(SimComponent): "Instance of SysLog used to create system logs." sim_root: Path "Root path of the simulation." + num_file_creations: int = 0 + "Number of file creations in the current step." + num_file_deletions: int = 0 + "Number of file deletions in the current step." def __init__(self, **kwargs): super().__init__(**kwargs) @@ -264,6 +268,8 @@ class FileSystem(SimComponent): ) folder.add_file(file) self._file_request_manager.add_request(name=file.name, request_type=RequestType(func=file._request_manager)) + # increment file creation + self.num_file_creations += 1 return file def get_file(self, folder_name: str, file_name: str, include_deleted: Optional[bool] = False) -> Optional[File]: @@ -324,6 +330,8 @@ class FileSystem(SimComponent): if folder: file = folder.get_file(file_name) if file: + # increment file deletion + self.num_file_deletions += 1 folder.remove_file(file) return True return False @@ -355,15 +363,14 @@ class FileSystem(SimComponent): """ file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) if file: - src_folder = file.folder - # remove file from src - src_folder.remove_file(file) + self.delete_file(folder_name=file.folder_name, file_name=file.name) dst_folder = self.get_folder(folder_name=dst_folder_name) if not dst_folder: dst_folder = self.create_folder(dst_folder_name) # add file to dst dst_folder.add_file(file) + self.num_file_creations += 1 if file.real: old_sim_path = file.sim_path file.sim_path = file.sim_root / file.path @@ -391,6 +398,10 @@ class FileSystem(SimComponent): folder_name=dst_folder.name, **file.model_dump(exclude={"uuid", "folder_id",
"folder_name", "sim_path"}), ) + self.num_file_creations += 1 + # increment access counter + file.num_access += 1 + dst_folder.add_file(file_copy, force=True) if file.real: @@ -408,12 +419,20 @@ class FileSystem(SimComponent): state = super().describe_state() state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()} state["deleted_folders"] = {folder.name: folder.describe_state() for folder in self.deleted_folders.values()} + state["num_file_creations"] = self.num_file_creations + state["num_file_deletions"] = self.num_file_deletions return state def apply_timestep(self, timestep: int) -> None: """Apply time step to FileSystem and its child folders and files.""" super().apply_timestep(timestep=timestep) + # reset number of file creations + self.num_file_creations = 0 + + # reset number of file deletions + self.num_file_deletions = 0 + # apply timestep to folders for folder_id in self.folders: self.folders[folder_id].apply_timestep(timestep=timestep) diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index c3ddff8a..6ebd8d14 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -138,7 +138,8 @@ class Folder(FileSystemItemABC): file = self.get_file_by_id(file_uuid=file_id) file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: - self.visible_health_status = FileSystemItemHealthStatus.CORRUPT + self.health_status = FileSystemItemHealthStatus.CORRUPT + self.visible_health_status = self.health_status def _reveal_to_red_timestep(self) -> None: """Apply reveal to red timestep.""" diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 513606a9..74013681 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -59,6 +59,16 @@ class 
Application(IOSoftware): ) return state + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the application. + + :param timestep: The current timestep of the simulation. + """ + super().apply_timestep(timestep=timestep) + + self.num_executions = 0 # reset number of executions + def _can_perform_action(self) -> bool: """ Checks if the application can perform actions. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index de7103f7..d3afef59 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -48,6 +48,7 @@ class DatabaseClient(Application): def execute(self) -> bool: """Execution definition for db client: perform a select query.""" + self.num_executions += 1 # trying to connect counts as an execution if self.connections: can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1]) else: diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 23e69e4d..ee276971 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -193,6 +193,8 @@ class DataManipulationBot(Application): if not self._can_perform_action(): _LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.") self.run() + + self.num_executions += 1 return self._application_loop() def _application_loop(self) -> bool: diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index a26570ed..e669ca32 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ 
b/src/primaite/simulator/system/applications/web_browser.py @@ -89,6 +89,8 @@ class WebBrowser(Application): if not self._can_perform_action(): return False + self.num_executions += 1 # trying to connect counts as an execution + # reset latest response self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND) diff --git a/tests/conftest.py b/tests/conftest.py index a117a1ef..20600e73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,8 @@ from _pytest.monkeypatch import MonkeyPatch from primaite import getLogger, PRIMAITE_PATHS from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent -from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations.observations import ICSObservation from primaite.game.agent.rewards import RewardFunction from primaite.game.game import PrimaiteGame from primaite.session.session import PrimaiteSession diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 3aff59af..a5fcb372 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -5,8 +5,9 @@ from typing import Union import yaml from primaite.config.load import data_manipulation_config_path -from primaite.game.agent.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.interface import ProxyAgent, RandomAgent +from primaite.game.agent.interface import ProxyAgent +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import 
ProbabilisticAgent from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -43,15 +44,15 @@ def test_example_config(): # green agent 1 assert "client_2_green_user" in game.agents - assert isinstance(game.agents["client_2_green_user"], RandomAgent) + assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent) # green agent 2 assert "client_1_green_user" in game.agents - assert isinstance(game.agents["client_1_green_user"], RandomAgent) + assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent) # red agent - assert "client_1_data_manipulation_red_bot" in game.agents - assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent) + assert "data_manipulation_attacker" in game.agents + assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent) # blue agent assert "defender" in game.agents diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py new file mode 100644 index 00000000..adbbf2b5 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -0,0 +1,25 @@ +from pathlib import Path +from typing import Union + +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_thresholds(): + """Test that the game options 
can be parsed correctly.""" + game = load_config(data_manipulation_config_path()) + + assert game.options.thresholds is not None diff --git a/tests/integration_tests/game_layer/observations/__init__.py b/tests/integration_tests/game_layer/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py new file mode 100644 index 00000000..93867edd --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -0,0 +1,66 @@ +import pytest + +from primaite.game.agent.observations.observations import AclObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_acl_observations(simulation): + """Test the ACL rule observations.""" + router: Router = simulation.network.get_node_by_hostname("router_1") + client_1: Computer = simulation.network.get_node_by_hostname("client_1") + server: Computer = simulation.network.get_node_by_hostname("server_1") + + # quick set up of ntp + client_1.software_manager.install(NTPClient) + ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient") + ntp_client.configure(server.network_interface.get(1).ip_address) + server.software_manager.install(NTPServer) + + # add router acl rule + 
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + + acl_obs = AclObservation( + where=["network", "nodes", router.hostname, "acl", "acl"], + node_ip_to_id={}, + ports=["NTP", "HTTP", "POSTGRES_SERVER"], + protocols=["TCP", "UDP", "ICMP"], + ) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 0 instead of 1) + assert rule_obs.get("permission") == 1 # permit = 1 deny = 2 + assert rule_obs.get("source_node_id") == 1 # applies to all source nodes + assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes + assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2) + assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2 + assert rule_obs.get("protocol") == 1 # 1 = No Protocol + + router.acl.remove_rule(1) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # the slot that held the NTP rule, which has now been removed + assert rule_obs.get("position") == 0 + assert rule_obs.get("permission") == 0 + assert rule_obs.get("source_node_id") == 0 + assert rule_obs.get("dest_node_id") == 0 + assert rule_obs.get("source_port") == 0 + assert rule_obs.get("dest_port") == 0 + assert rule_obs.get("protocol") == 0 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py new file mode 100644 index 00000000..35bb95fd --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -0,0 +1,70 @@ +import pytest +from gymnasium import spaces + +from
primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_file_observation(simulation): + """Test the file observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file on the pc + file = pc.file_system.create_file(file_name="dog.png") + + dog_file_obs = FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] + ) + + assert dog_file_obs.space["health_status"] == spaces.Discrete(6) + + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # good initial + + file.corrupt() + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan file so this changes + + file.scan() + file.apply_timestep(0) # apply time step + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # corrupted + + +def test_folder_observation(simulation): + """Test the folder observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file and folder on the pc + folder = pc.file_system.create_folder("test_folder") + file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") + + root_folder_obs = FolderObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] + ) + + assert root_folder_obs.space["health_status"] == spaces.Discrete(6) + + observation_state = root_folder_obs.observe(simulation.describe_state()) + 
assert observation_state.get("FILES") is not None + assert observation_state.get("health_status") == 1 + + file.corrupt() # corrupt just the file + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan folder to change this + + folder.scan() + for i in range(folder.scan_duration + 1): + folder.apply_timestep(i) # apply as many timesteps as needed for a scan + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py new file mode 100644 index 00000000..bfe4d5cc --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -0,0 +1,73 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.observations import LinkObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.base import Link, Node +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation() -> Simulation: + sim = Simulation() + + network = Network() + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Connect Computer and Server + 
network.connect(computer.network_interface[1], server.network_interface[1]) + + # Should be linked + assert next(iter(network.links.values())).is_up + + assert computer.ping(server.network_interface.get(1).ip_address) + + # set simulation network as example network + sim.network = network + + return sim + + +def test_link_observation(simulation): + """Test the link observation.""" + # get a link + link: Link = next(iter(simulation.network.links.values())) + + computer: Computer = simulation.network.get_node_by_hostname("computer") + server: Server = simulation.network.get_node_by_hostname("server") + + simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep + + link_obs = LinkObservation(where=["network", "links", link.uuid]) + + assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10 + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state.get("PROTOCOLS") is not None + assert observation_state["PROTOCOLS"]["ALL"] == 0 + + computer.ping(server.network_interface.get(1).ip_address) + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state["PROTOCOLS"]["ALL"] == 1 diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py new file mode 100644 index 00000000..332bc1f7 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -0,0 +1,97 @@ +from pathlib import Path +from typing import Union + +import pytest +import yaml +from gymnasium import spaces + +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.network.nmne 
import CAPTURE_NMNE +from primaite.simulator.sim_container import Simulation +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_nic(simulation): + """Test the NIC observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic: NIC = pc.network_interface[1] + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + assert nic_obs.space["nic_status"] == spaces.Discrete(3) + assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) + assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4) + + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 1 # enabled + assert observation_state.get("NMNE") is not None + assert observation_state["NMNE"].get("inbound") == 0 + assert observation_state["NMNE"].get("outbound") == 0 + + nic.disable() + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 2 # disabled + + +def test_nic_categories(simulation): + """Test the NIC observation nmne count categories.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + assert nic_obs.high_nmne_threshold == 10 # default + assert nic_obs.med_nmne_threshold == 5 # default + assert nic_obs.low_nmne_threshold == 0 # default + + nic_obs = NicObservation( + where=["network", "nodes", 
pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + assert nic_obs.high_nmne_threshold == 9 + assert nic_obs.med_nmne_threshold == 6 + assert nic_obs.low_nmne_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=9, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=9, + high_nmne_threshold=9, + ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py new file mode 100644 index 00000000..dce05b6a --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -0,0 +1,46 @@ +import copy +from uuid import uuid4 + +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_node_observation(simulation): + """Test a Node observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + + assert node_obs.space["operating_status"] == spaces.Discrete(5) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 1 # computer is on + + assert observation_state.get("SERVICES") is not None + assert observation_state.get("FOLDERS") 
is not None + assert observation_state.get("NICS") is not None + + # turn off computer + pc.power_off() + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 4 # shutting down + + for i in range(pc.shut_down_duration + 1): + pc.apply_timestep(i) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 2 diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py new file mode 100644 index 00000000..4ae0701e --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -0,0 +1,70 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_service_observation(simulation): + """Test the service observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(NTPServer) + + ntp_server = pc.software_manager.software.get("NTPServer") + assert ntp_server + + service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + + assert service_obs.space["operating_status"] == spaces.Discrete(7) + assert 
service_obs.space["health_status"] == spaces.Discrete(5) + + observation_state = service_obs.observe(simulation.describe_state()) + + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 1 # running + + ntp_server.restart() + observation_state = service_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 6 # resetting + + +def test_application_observation(simulation): + """Test the application observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(DatabaseClient) + + web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") + assert web_browser + + app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + + web_browser.close() + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 2 # stopped + assert observation_state.get("num_executions") == 0 + + web_browser.run() + web_browser.scan() # scan to update health status + web_browser.get_webpage("test") + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 + assert observation_state.get("operating_status") == 1 # running + assert observation_state.get("num_executions") == 1 diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 8911632c..740fb491 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -10,28 +10,14 @@ # 4. Check that the simulation has changed in the way that I expect. # 5. Repeat for all actions. 
-from typing import Dict, Tuple +from typing import Tuple import pytest -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, ProxyAgent -from primaite.game.agent.observations import ICSObservation, ObservationManager -from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.simulator.system.services.dns.dns_server import DNSServer -from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index d1301759..f52b52f7 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -1,6 +1,6 @@ from gymnasium import spaces -from primaite.game.agent.observations import FileObservation +from primaite.game.agent.observations.file_system_observations import FileObservation from 
primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 43bb176b..9efc70f7 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,4 +1,4 @@ -from primaite.game.agent.observations import NicObservation +from primaite.game.agent.observations.nic_observations import NicObservation from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.sim_container import Simulation @@ -179,8 +179,8 @@ def test_capture_nmne_observations(uc2_network): # Observe the current state of NMNEs from the NICs of both the database and web servers state = sim.describe_state() - db_nic_obs = db_server_nic_obs.observe(state)["nmne"] - web_nic_obs = web_server_nic_obs.observe(state)["nmne"] + db_nic_obs = db_server_nic_obs.observe(state)["NMNE"] + web_nic_obs = web_server_nic_obs.observe(state)["NMNE"] # Define expected NMNE values based on the iteration count if i > 10: diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index 73228e36..c556cfad 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -1,7 +1,8 @@ from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations.observations import ICSObservation from primaite.game.agent.rewards import RewardFunction -from primaite.game.agent.scripted_agents import ProbabilisticAgent +from 
primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent def test_probabilistic_agent(): diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 4defc80c..05824834 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,7 +1,9 @@ import pytest +from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.file_system.file_type import FileType +from primaite.simulator.file_system.folder import Folder def test_create_folder_and_file(file_system): @@ -14,8 +16,15 @@ def test_create_folder_and_file(file_system): assert len(file_system.get_folder("test_folder").files) == 1 + assert file_system.num_file_creations == 1 + assert file_system.get_folder("test_folder").get_file("test_file.txt") + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) @@ -23,24 +32,37 @@ def test_create_file_no_folder(file_system): """Tests that creating a file without a folder creates a folder and sets that as the file's parent.""" file = file_system.create_file(file_name="test_file.txt", size=10) assert len(file_system.folders) is 1 + assert file_system.num_file_creations == 1 assert file_system.get_folder("root").get_file("test_file.txt") == file assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT assert file_system.get_folder("root").get_file("test_file.txt").size == 10 + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) def test_delete_file(file_system): """Tests that a file can be deleted.""" - file_system.create_file(file_name="test_file.txt") + file = 
file_system.create_file(file_name="test_file.txt") assert len(file_system.folders) == 1 assert len(file_system.get_folder("root").files) == 1 file_system.delete_file(folder_name="root", file_name="test_file.txt") + assert file.num_access == 1 + assert file_system.num_file_deletions == 1 assert len(file_system.folders) == 1 assert len(file_system.get_folder("root").files) == 0 assert len(file_system.get_folder("root").deleted_files) == 1 + file_system.apply_timestep(0) + + # num file deletions should reset + assert file_system.num_file_deletions == 0 + file_system.show(full=True) @@ -54,6 +76,7 @@ def test_delete_non_existent_file(file_system): # deleting should not change how many files are in folder file_system.delete_file(folder_name="root", file_name="does_not_exist!") + assert file_system.num_file_deletions == 0 # should still only be one folder assert len(file_system.folders) == 1 @@ -96,6 +119,7 @@ def test_create_duplicate_file(file_system): assert len(file_system.folders) is 2 file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + assert file_system.num_file_creations == 1 assert len(file_system.get_folder("test_folder").files) == 1 @@ -103,6 +127,7 @@ def test_create_duplicate_file(file_system): file_system.create_file(file_name="test_file.txt", folder_name="test_folder") assert len(file_system.get_folder("test_folder").files) == 1 + assert file_system.num_file_creations == 1 file_system.show(full=True) @@ -136,13 +161,24 @@ def test_move_file(file_system): assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 0 + assert file_system.num_file_deletions == 0 + assert file_system.num_file_creations == 1 file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") + assert file_system.num_file_deletions == 1 + assert file_system.num_file_creations == 2 + assert file.num_access == 1 assert 
len(file_system.get_folder("src_folder").files) == 0 assert len(file_system.get_folder("dst_folder").files) == 1 assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid + file_system.apply_timestep(0) + + # num file creations and deletions should reset + assert file_system.num_file_creations == 0 + assert file_system.num_file_deletions == 0 + file_system.show(full=True) @@ -152,17 +188,25 @@ def test_copy_file(file_system): file_system.create_folder(folder_name="dst_folder") file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True) + assert file_system.num_file_creations == 1 original_uuid = file.uuid assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 0 file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") + assert file_system.num_file_creations == 2 + assert file.num_access == 1 assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 1 assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) @@ -172,13 +216,17 @@ def test_get_file(file_system): file1: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") file2: File = file_system.create_file(file_name="test_file2.txt", folder_name="test_folder") - folder.remove_file(file2) + file_system.delete_file("test_folder", "test_file2.txt") + # file 2 was accessed before being deleted + assert file2.num_access == 1 assert file_system.get_file_by_id(file_uuid=file1.uuid, folder_uuid=folder.uuid) is not None assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid) is None assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid, 
include_deleted=True) is not None assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None + assert file2.num_access == 1 # cannot access deleted file + file_system.delete_folder(folder_name="test_folder") assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None