diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py index 522cdb59..70a83881 100644 --- a/src/primaite/game/agent/observations/agent_observations.py +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -2,12 +2,12 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces +from primaite.game.agent.observations.node_observations import NodeObservation from primaite.game.agent.observations.observations import ( AbstractObservation, AclObservation, ICSObservation, LinkObservation, - NodeObservation, NullObservation, ) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py new file mode 100644 index 00000000..277bc51f --- /dev/null +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class FileObservation(AbstractObservation): + """Observation of a file on a node in the network.""" + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """ + Initialise file observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',,'files',] + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.default_observation: spaces.Space = {"health_status": 0} + "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + file_state = access_from_nested_dict(state, self.where) + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"health_status": file_state["visible_status"]} + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"health_status": spaces.Discrete(6)}) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": + """Create file observation from a config. + + :param config: Dictionary containing the configuration for this file observation. + :type config: Dict + :param game: _description_ + :type game: PrimaiteGame + :param parent_where: _description_, defaults to None + :type parent_where: _type_, optional + :return: _description_ + :rtype: _type_ + """ + return cls(where=parent_where + ["files", config["file_name"]]) + + +class FolderObservation(AbstractObservation): + """Folder observation, including files inside of the folder.""" + + def __init__( + self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 + ) -> None: + """Initialise folder Observation, including files inside the folder. + + :param where: Where in the simulation state dictionary to find the relevant information for this folder. + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',] + :type where: Optional[List[str]] + :param max_files: As size of the space must remain static, define max files that can be in this folder + , defaults to 5 + :type max_files: int, optional + :param file_positions: Defines the positioning within the observation space of particular files. This ensures + that even if new files are created, the existing files will always occupy the same space in the observation + space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the + observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same + name, it will take the position defined in this dict. Defaults to {} + :type file_positions: Dict[int, str], optional + """ + super().__init__() + + self.where: Optional[Tuple[str]] = where + + self.files: List[FileObservation] = files + while len(self.files) < num_files_per_folder: + self.files.append(FileObservation()) + while len(self.files) > num_files_per_folder: + truncated_file = self.files.pop() + msg = f"Too many files in folder observation. Truncating file {truncated_file}" + _LOGGER.warning(msg) + + self.default_observation = { + "health_status": 0, + "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + ) -> "FolderObservation": + """Create folder observation from a config. Also creates child file observations. + + :param config: Dictionary containing the configuration for this folder observation. Includes the name of the + folder and the files inside of it. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this folder's + parent node. A typical location for a node ``where`` can be: + ['network','nodes',,'file_system'] + :type parent_where: Optional[List[str]] + :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed folder observation + :rtype: FolderObservation + """ + where = parent_where + ["folders", config["folder_name"]] + + file_configs = config["files"] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] + + return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py new file mode 100644 index 00000000..93c6765b --- /dev/null +++ b/src/primaite/game/agent/observations/node_observations.py @@ -0,0 +1,199 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.file_system_observations import FolderObservation +from primaite.game.agent.observations.observations import AbstractObservation, NicObservation +from primaite.game.agent.observations.software_observation import ServiceObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NodeObservation(AbstractObservation): + """Observation of a node in the network. Includes services, folders and NICs.""" + + def __init__( + self, + where: Optional[Tuple[str]] = None, + services: List[ServiceObservation] = [], + folders: List[FolderObservation] = [], + network_interfaces: List[NicObservation] = [], + logon_status: bool = False, + num_services_per_node: int = 2, + num_folders_per_node: int = 2, + num_files_per_folder: int = 2, + num_nics_per_node: int = 2, + ) -> None: + """ + Configurable observation for a node in the simulation. + + :param where: Where in the simulation state dictionary for find relevant information for this observation. + A typical location for a node looks like this: + ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] + :type where: List[str], optional + :param services: Mapping between position in observation space and service name, defaults to {} + :type services: Dict[int,str], optional + :param max_services: Max number of services that can be presented in observation space for this node + , defaults to 2 + :type max_services: int, optional + :param folders: Mapping between position in observation space and folder name, defaults to {} + :type folders: Dict[int,str], optional + :param max_folders: Max number of folders in this node's obs space, defaults to 2 + :type max_folders: int, optional + :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} + :type network_interfaces: Dict[int,str], optional + :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 + :type max_nics: int, optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + self.services: List[ServiceObservation] = services + while len(self.services) < num_services_per_node: + # add empty service observation without `where` parameter so it always returns default (blank) observation + self.services.append(ServiceObservation()) + while len(self.services) > num_services_per_node: + truncated_service = self.services.pop() + msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" + _LOGGER.warning(msg) + # truncate service list + + self.folders: List[FolderObservation] = folders + # add empty folder observation without `where` parameter that will always return default (blank) observations + while len(self.folders) < num_folders_per_node: + self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) + while len(self.folders) > num_folders_per_node: + truncated_folder = self.folders.pop() + msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" + _LOGGER.warning(msg) + + self.network_interfaces: List[NicObservation] = network_interfaces + while len(self.network_interfaces) < num_nics_per_node: + self.network_interfaces.append(NicObservation()) + while len(self.network_interfaces) > num_nics_per_node: + truncated_nic = self.network_interfaces.pop() + msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" + _LOGGER.warning(msg) + + self.logon_status: bool = logon_status + + self.default_observation: Dict = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "operating_status": 0, + } + if self.logon_status: + self.default_observation["logon_status"] = 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NETWORK_INTERFACES"] = { + i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) + } + + if self.logon_status: + obs["logon_status"] = 0 + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + space_shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NETWORK_INTERFACES": spaces.Dict( + {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} + ), + } + if self.logon_status: + space_shape["logon_status"] = spaces.Discrete(3) + + return spaces.Dict(space_shape) + + @classmethod + def from_config( + cls, + config: Dict, + game: "PrimaiteGame", + parent_where: Optional[List[str]] = None, + num_services_per_node: int = 2, + num_folders_per_node: int = 2, + num_files_per_folder: int = 2, + num_nics_per_node: int = 2, + ) -> "NodeObservation": + """Create node observation from a config. Also creates child service, folder and NIC observations. + + :param config: Dictionary containing the configuration for this node observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this node's parent + network. A typical location for it would be: ['network',] + :type parent_where: Optional[List[str]] + :param num_services_per_node: How many spaces for services are in this node observation (to preserve static + observation size) , defaults to 2 + :type num_services_per_node: int, optional + :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static + observation size) , defaults to 2 + :type num_folders_per_node: int, optional + :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed node observation + :rtype: NodeObservation + """ + node_hostname = config["node_hostname"] + if parent_where is None: + where = ["network", "nodes", node_hostname] + else: + where = parent_where + ["nodes", node_hostname] + + svc_configs = config.get("services", {}) + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] + folder_configs = config.get("folders", {}) + folders = [ + FolderObservation.from_config( + config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder + ) + for c in folder_configs + ] + # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. + nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] + network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] + logon_status = config.get("logon_status", False) + return cls( + where=where, + services=services, + folders=folders, + network_interfaces=network_interfaces, + logon_status=logon_status, + num_services_per_node=num_services_per_node, + num_folders_per_node=num_folders_per_node, + num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, + ) diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 6d6614f4..10e69ea5 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -46,128 +46,6 @@ class AbstractObservation(ABC): pass -class FileObservation(AbstractObservation): - """Observation of a file on a node in the network.""" - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """ - Initialise file observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',,'files',] - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.default_observation: spaces.Space = {"health_status": 0} - "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - file_state = access_from_nested_dict(state, self.where) - if file_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return {"health_status": file_state["visible_status"]} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"health_status": spaces.Discrete(6)}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": - """Create file observation from a config. - - :param config: Dictionary containing the configuration for this file observation. - :type config: Dict - :param game: _description_ - :type game: PrimaiteGame - :param parent_where: _description_, defaults to None - :type parent_where: _type_, optional - :return: _description_ - :rtype: _type_ - """ - return cls(where=parent_where + ["files", config["file_name"]]) - - -class ServiceObservation(AbstractObservation): - """Observation of a service in the network.""" - - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} - "Default observation is what should be returned when the service doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise service observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'services', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - service_state = access_from_nested_dict(state, self.where) - if service_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], - } - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ServiceObservation": - """Create service observation from a config. - - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation - :rtype: ServiceObservation - """ - return cls(where=parent_where + ["services", config["service_name"]]) - - class LinkObservation(AbstractObservation): """Observation of a link in the network.""" @@ -238,111 +116,6 @@ class LinkObservation(AbstractObservation): return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -class FolderObservation(AbstractObservation): - """Folder observation, including files inside of the folder.""" - - def __init__( - self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 - ) -> None: - """Initialise folder Observation, including files inside of the folder. - - :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',] - :type where: Optional[List[str]] - :param max_files: As size of the space must remain static, define max files that can be in this folder - , defaults to 5 - :type max_files: int, optional - :param file_positions: Defines the positioning within the observation space of particular files. This ensures - that even if new files are created, the existing files will always occupy the same space in the observation - space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the - observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same - name, it will take the position defined in this dict. Defaults to {} - :type file_positions: Dict[int, str], optional - """ - super().__init__() - - self.where: Optional[Tuple[str]] = where - - self.files: List[FileObservation] = files - while len(self.files) < num_files_per_folder: - self.files.append(FileObservation()) - while len(self.files) > num_files_per_folder: - truncated_file = self.files.pop() - msg = f"Too many files in folder observation. Truncating file {truncated_file}" - _LOGGER.warning(msg) - - self.default_observation = { - "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - folder_state = access_from_nested_dict(state, self.where) - if folder_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - health_status = folder_state["health_status"] - - obs = {} - - obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 - ) -> "FolderObservation": - """Create folder observation from a config. Also creates child file observations. - - :param config: Dictionary containing the configuration for this folder observation. Includes the name of the - folder and the files inside of it. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node ``where`` can be: - ['network','nodes',,'file_system'] - :type parent_where: Optional[List[str]] - :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed folder observation - :rtype: FolderObservation - """ - where = parent_where + ["folders", config["folder_name"]] - - file_configs = config["files"] - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] - - return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) - - class NicObservation(AbstractObservation): """Observation of a Network Interface Card (NIC) in the network.""" @@ -444,191 +217,6 @@ class NicObservation(AbstractObservation): return cls(where=parent_where + ["NICs", config["nic_num"]]) -class NodeObservation(AbstractObservation): - """Observation of a node in the network. Includes services, folders and NICs.""" - - def __init__( - self, - where: Optional[Tuple[str]] = None, - services: List[ServiceObservation] = [], - folders: List[FolderObservation] = [], - network_interfaces: List[NicObservation] = [], - logon_status: bool = False, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> None: - """ - Configurable observation for a node in the simulation. - - :param where: Where in the simulation state dictionary for find relevant information for this observation. - A typical location for a node looks like this: - ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] - :type where: List[str], optional - :param services: Mapping between position in observation space and service name, defaults to {} - :type services: Dict[int,str], optional - :param max_services: Max number of services that can be presented in observation space for this node - , defaults to 2 - :type max_services: int, optional - :param folders: Mapping between position in observation space and folder name, defaults to {} - :type folders: Dict[int,str], optional - :param max_folders: Max number of folders in this node's obs space, defaults to 2 - :type max_folders: int, optional - :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} - :type network_interfaces: Dict[int,str], optional - :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 - :type max_nics: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.services: List[ServiceObservation] = services - while len(self.services) < num_services_per_node: - # add empty service observation without `where` parameter so it always returns default (blank) observation - self.services.append(ServiceObservation()) - while len(self.services) > num_services_per_node: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - # truncate service list - - self.folders: List[FolderObservation] = folders - # add empty folder observation without `where` parameter that will always return default (blank) observations - while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) - while len(self.folders) > num_folders_per_node: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NicObservation] = network_interfaces - while len(self.network_interfaces) < num_nics_per_node: - self.network_interfaces.append(NicObservation()) - while len(self.network_interfaces) > num_nics_per_node: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warning(msg) - - self.logon_status: bool = logon_status - - self.default_observation: Dict = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, - } - if self.logon_status: - self.default_observation["logon_status"] = 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NETWORK_INTERFACES"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - - if self.logon_status: - obs["logon_status"] = 0 - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - space_shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NETWORK_INTERFACES": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - } - if self.logon_status: - space_shape["logon_status"] = spaces.Discrete(3) - - return spaces.Dict(space_shape) - - @classmethod - def from_config( - cls, - config: Dict, - game: "PrimaiteGame", - parent_where: Optional[List[str]] = None, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> "NodeObservation": - """Create node observation from a config. Also creates child service, folder and NIC observations. - - :param config: Dictionary containing the configuration for this node observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this node's parent - network. A typical location for it would be: ['network',] - :type parent_where: Optional[List[str]] - :param num_services_per_node: How many spaces for services are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_services_per_node: int, optional - :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_folders_per_node: int, optional - :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed node observation - :rtype: NodeObservation - """ - node_hostname = config["node_hostname"] - if parent_where is None: - where = ["network", "nodes", node_hostname] - else: - where = parent_where + ["nodes", node_hostname] - - svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] - folder_configs = config.get("folders", {}) - folders = [ - FolderObservation.from_config( - config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder - ) - for c in folder_configs - ] - # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. - nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] - network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] - logon_status = config.get("logon_status", False) - return cls( - where=where, - services=services, - folders=folders, - network_interfaces=network_interfaces, - logon_status=logon_status, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - - class AclObservation(AbstractObservation): """Observation of an Access Control List (ACL) in the network.""" diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py new file mode 100644 index 00000000..eae9dc1f --- /dev/null +++ b/src/primaite/game/agent/observations/software_observation.py @@ -0,0 +1,71 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class ServiceObservation(AbstractObservation): + """Observation of a service in the network.""" + + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} + "Default observation is what should be returned when the service doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise service observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'services', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + service_state = access_from_nested_dict(state, self.where) + if service_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": service_state["operating_state"], + "health_status": service_state["health_state_visible"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None + ) -> "ServiceObservation": + """Create service observation from a config. + + :param config: Dictionary containing the configuration for this service observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. + :type parent_where: Optional[List[str]], optional + :return: Constructed service observation + :rtype: ServiceObservation + """ + return cls(where=parent_where + ["services", config["service_name"]]) diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index b6aed30b..f52b52f7 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -1,6 +1,6 @@ from gymnasium import spaces -from primaite.game.agent.observations.observations import FileObservation +from primaite.game.agent.observations.file_system_observations import FileObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation