diff --git a/example_config.yaml b/example_config.yaml new file mode 100644 index 00000000..6c02031a --- /dev/null +++ b/example_config.yaml @@ -0,0 +1,276 @@ +training_config: + rl_framework: SB3 + rl_algo: PPO + n_learn_steps: 128 + n_learn_episodes: 1000 + +game_config: + ports: + - ARP + - DNS + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + + agents: + - ref: client_1_green_user + team: GREEN + team: SCRIPTED_GREEN_ + observation_space: + ... + action_space: + ... + reward_function: + - type: null_reward + # node_ref: client_1 + # service: WebBrowser + # pol: + # - step: 1 + # action: START + + - ref: client_1_data_manipulation_red_bot + team: RED + type: SCRIPTED_RED_ + observation_space: + network: + nodes: + - ref: client_1 + - logon_status + - operating_status + services: + - ref: data_manipulation_bot + - operating_status + - health_status + folders: + files: {} + nics: {} + + action_space: + actions: + - DO_NOTHING + network: + nodes: + - ref: client_1 + actions: + - SCAN + - LOGON + - LOGOFF + services: + - ref: data_manipulation_bot + actions: + - type: COMPROMISE + execution_definition: + server_ip: 192.168.1.14 + payload: "DROP TABLE IF EXISTS user;" + success_rate: 80% + folders: + files: {} + reward_function: null + options: # options specific to this particular agent type, basically args of __init__(self) + start_step: 25 + frequency: 20 + variance: 5 + + + + + - ref: defender + team: blue + type: GATE_RL_AGENT + observation_space: + network: + nodes: + - ref: + action_space: + ... + reward_function: + ... + + + + + +simulation: + network: + nodes: + + - ref: router_1 + type: router + hostname: router_1 + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + acl: + 0: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 1: + action: PERMIT + src_port: DNS + dst_port: DNS + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - ref: switch_1 + type: swtich + hostname: switch_1 + num_ports: 8 + + - ref: switch_2 + type: switch + hostname: switch_2 + num_ports: 8 + + - ref: domain_controller + type: server + hostname: domain_controller + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - ref: domain_controller_dns_server + type: dns_server + options: + domain_mapping: + - arcd.com: 192.168.1.12 # web server + + + - ref: web_server + type: server + hostname: web_server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.10 + dns_server: 192.168.1.10 + services: + - ref: web_server_database_client + type: database_client + options: + db_server_ip: 192.168.1.14 + + - ref: database_server + type: server + hostname: database_server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: database_service + type: database_service + + + - ref: backup_server + type: node + hostname: backup_server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: backup_service + type: database_backup + + - ref: security_suite + type: server + hostname: security_suite + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + nics: + 2: + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + + - ref: client_1 + type: computer + hostname: client_1 + ip_address: 192.168.10.21. + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + services: + - ref: data_manipulation_bot + type: data_manipulation_bot + - ref: client_1_dns_client + type: dns_client + + - ref: client_2 + type: computer + hostname: client_2 + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + services: + - ref: web_browser + type: web_browser + - ref: client_2_dns_client + type: dns_client + + + links: + - ref: router_1___switch_1 + endpoint_a: router_1 + endpoint_a_port: 1 + endpoint_b: switch_1 + endpoint_b_port: 8 + - ref: router_1___switch_2 + endpoint_a: router_1 + endpoint_a_port: 2 + endpoint_b: switch_2 + endpoint_b_port: 8 + - ref: switch_1___domain_controller + endpoint_a: switch_1 + endpoint_a_port: 1 + endpoint_b: domain_controller + endpoint_b_port: 1 + - ref: switch_1___web_server + endpoint_a: switch_1 + endpoint_a_port: 2 + endpoint_b: web_server + endpoint_b_port: 1 + - ref: switch_1___database_server + endpoint_a: switch_1 + endpoint_a_port: 3 + endpoint_b: database_server + endpoint_b_port: 1 + - ref: switch_1___backup_server + endpoint_a: switch_1 + endpoint_a_port: 4 + endpoint_b: backup_server + endpoint_b_port: 1 + - ref: switch_1___security_suite + endpoint_a: switch_1 + endpoint_a_port: 7 + endpoint_b: security_suite + endpoint_b_port: 1 + - ref: switch_2___client_1 + endpoint_a: switch_2 + endpoint_a_port: 1 + endpoint_b: client_1 + endpoint_b_port: 1 + - ref: switch_2___client_2 + endpoint_a: switch_2 + endpoint_a_port: 2 + endpoint_b: client_2 + endpoint_b_port: 1 + - ref: switch_2___security_suite + endpoint_a: switch_2 + endpoint_a_port: 7 + endpoint_b: security_suite + endpoint_b_port: 2 diff --git a/src/primaite/game/actor/interface.py b/src/primaite/game/actor/interface.py index 1fe43a32..d1245e71 100644 --- a/src/primaite/game/actor/interface.py +++ b/src/primaite/game/actor/interface.py @@ -11,10 +11,13 @@ from primaite.game.actor.observations import ObservationSpace from primaite.game.actor.rewards import RewardFunction -class AbstractActor(BaseModel): +class AbstractActor(ABC): """Base class for scripted and RL agents.""" - ... + def __init__(self) -> None: + self.action_space = ActionSpace + self.observation_space = ObservationSpace + self.reward_function = RewardFunction class AbstractScriptedActor(AbstractActor): diff --git a/src/primaite/game/actor/observations.py b/src/primaite/game/actor/observations.py index 7303b07b..4d4796e1 100644 --- a/src/primaite/game/actor/observations.py +++ b/src/primaite/game/actor/observations.py @@ -1,9 +1,16 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Hashable, List - -from pydantic import BaseModel +from typing import Any, Dict, Hashable, List, Optional from gym import spaces +from pydantic import BaseModel + +from primaite.simulator.sim_container import Simulation + +NOT_PRESENT_IN_STATE = object() +""" +Need an object to return when the sim state does not contain a requested value. Cannot use None because sometimes +the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose. +""" def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any: @@ -20,19 +27,17 @@ def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any: :return: The value in the dictionary :rtype: Any """ - if not keys: + if len(keys) == 0: return dictionary k = keys.pop(0) - try: - return access_from_nested_dict(dictionary[k], keys) - except (TypeError, KeyError): - raise KeyError(f"Cannot find requested key `{k}` in nested dictionary") + if k not in dictionary: + return NOT_PRESENT_IN_STATE + return access_from_nested_dict(dictionary[k], keys) -class AbstractObservation(BaseModel): - +class AbstractObservation(ABC): @abstractmethod - def __call__(self, state: Dict) -> Any: + def observe(self, state: Dict) -> Any: """_summary_ :param state: _description_ @@ -41,7 +46,6 @@ class AbstractObservation(BaseModel): :rtype: Any """ ... - # receive state dict @property @abstractmethod @@ -51,72 +55,396 @@ class AbstractObservation(BaseModel): class FileObservation(AbstractObservation): - where: List[str] - """Store information about where in the simulation state dictionary to find the relevatn information.""" + def __init__(self, where: List[str] = []) -> None: + """ + _summary_ - def __call__(self, state: Dict) -> Dict: + :param where: Store information about where in the simulation state dictionary to find the relevatn information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',,'files',] + :type where: Optional[List[str]] + """ + super().__init__() + self.where: List[str] = where + self.default_observation: spaces.Space = {"health_status": 0} + "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation file_state = access_from_nested_dict(state, self.where) - observation = {'health_status':file_state['health_status']} - return observation + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"health_status": file_state["health_status"]} @property def space(self) -> spaces.Space: - return spaces.Dict({'health_status':spaces.Discrete(6)}) + return spaces.Dict({"health_status": spaces.Discrete(6)}) + + +class ServiceObservation(AbstractObservation): + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} + "Default observation is what should be returned when the service doesn't exist." + + def __init__(self, where: List[str] = []) -> None: + """ + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'servics', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: List[str] = where + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation + + service_state = access_from_nested_dict(state, self.where) + if service_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"operating_status": service_state["operating_status"], "health_status": service_state["health_status"]} + + @property + def space(self) -> spaces.Space: + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) + + +class LinkObservation(AbstractObservation): + default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}} + "Default observation is what should be returned when the link doesn't exist." + + def __init__(self, where: List[str] = []) -> None: + """ + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'servics', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: List[str] = where + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation + + link_state = access_from_nested_dict(state, self.where) + if link_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + bandwidth = link_state["bandwidth"] + load = link_state["current_load"] + utilisation_fraction = load / bandwidth + # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% + utilisation_category = int(utilisation_fraction * 10) + 1 + + # TODO: once the links support separte load per protocol, this needs amendment to reflect that. + return {"protocols": {"all": {"load": utilisation_category}}} + + @property + def space(self) -> spaces.Space: + return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})}) + + +class FolderObservation(AbstractObservation): + def __init__(self, where: List[str] = [], files: List[FileObservation] = []) -> None: + """Initialise folder Observation, including files inside of the folder. + + :param where: Where in the simulation state dictionary to find the relevant information for this folder. + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',] + :type where: Optional[List[str]] + :param max_files: As size of the space must remain static, define max files that can be in this folder + , defaults to 5 + :type max_files: int, optional + :param file_positions: Defines the positioning within the observation space of particular files. This ensures + that even if new files are created, the existing files will always occupy the same space in the observation + space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the + observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same + name, it will take the position defined in this dict. Defaults to {} + :type file_positions: Dict[int, str], optional + """ + super().__init__() + + self.where: List[str] = where + + self.files: List[FileObservation] = files + + self.default_observation = { + "health_status": 0, + "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, + } + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + + +class NicObservation(AbstractObservation): + default_observation: spaces.Space = {"nic_status": 0} + + def __init__(self, where: List[str] = []) -> None: + super.__init__() + self.where: List[str] = where + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation + nic_state = access_from_nested_dict(state, self.where) + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + else: + return {"nic_status": 1 if nic_state["enabled"] else 2} + + @property + def space(self) -> spaces.Space: + return spaces.Dict({"nic_status": spaces.Discrete(3)}) + + +class NodeObservation(AbstractObservation): + def __init__( + self, + where: List[str] = [], + services: List[ServiceObservation] = [], + folders: List[FolderObservation] = [], + nics: List[NicObservation] = [], + ) -> None: + """ + Configurable observation for a node in the simulation. + + :param where: Where in the simulation state dictionary for find relevant information for this observation. + A typical location for a node looks like this: + ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] + :type where: List[str], optional + :param services: Mapping between position in observation space and service UUID, defaults to {} + :type services: Dict[int,str], optional + :param max_services: Max number of services that can be presented in observation space for this node, defaults to 2 + :type max_services: int, optional + :param folders: Mapping between position in observation space and folder name, defaults to {} + :type folders: Dict[int,str], optional + :param max_folders: Max number of folders in this node's obs space, defaults to 2 + :type max_folders: int, optional + :param nics: Mapping between position in observation space and NIC UUID, defaults to {} + :type nics: Dict[int,str], optional + :param max_nics: Max number of NICS in this node's obs space, defaults to 5 + :type max_nics: int, optional + """ + super.__init__() + self.where: List[str] = where + + self.services: List[ServiceObservation] = services + self.folders: List[FolderObservation] = folders + self.nics: List[NicObservation] = nics + + self.default_observation: Dict = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)}, + "operating_status": 0, + } + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation + + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} + + return obs + + @property + def space(self) -> spaces.Space: + return spaces.Dict( + { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(0), + "NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}), + } + ) + + +class AclObservation(AbstractObservation): + # TODO: should where be optional, and we can use where=None to pad the observation space? + # definitely the current approach does not support tracking files that aren't specified by name, for example + # if a file is created at runtime, we have currently got no way of telling the observation space to track it. + # this needs adding, but not for the MVP. + def __init__( + self, nodes: List[str], ports: List[int], protocols: list[str], where: List[str] = [], num_rules: int = 10 + ) -> None: + super().__init__() + self.where: List[str] = where + self.num_rules: int = num_rules + self.node_to_id: Dict[str, int] = {node: i + 1 for i, node in enumerate(nodes)} + "List of node IP addresses, order in this list determines how they are converted to an ID" + self.port_to_id: Dict[int, int] = {port: i + 1 for i, port in enumerate(ports)} + "List of ports which are part of the game that define the ordering when converting to an ID" + self.protocol_to_id: Dict[str, int] = {protocol: i + 1 for i, protocol in enumerate(protocols)} + "List of protocols which are part of the game, defines ordering when converting to an ID" + self.default_observation: spaces.Space = spaces.Dict( + { + "RULES": spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + ) + for i in range(self.num_rules) + } + ) + } + ) + + def observe(self, state: Dict) -> Dict: + if not self.where: + return self.default_observation + acl_state: Dict = access_from_nested_dict(state, self.where) + if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["RULES"] = {} + for i, rule_state in acl_state.items(): + if rule_state is None: + obs["RULES"][i + 1] = { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + else: + obs["RULES"][i + 1] = { + "position": i, + "permission": rule_state["action"], + "source_node_id": self.node_to_id[rule_state["src_ip_address"]], + "source_port": self.port_to_id[rule_state["src_port"]], + "dest_node_id": self.node_to_id[rule_state["dst_ip_address"]], + "dest_port": self.port_to_id[rule_state["dst_port"]], + "protocol": self.protocol_to_id[rule_state["protocol"]], + } + return obs + + @property + def space(self) -> spaces.Space: + return spaces.Dict( + { + "RULE": spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + "source_node_id": spaces.Discrete(len(self.nodes) + 1), + "source_port": spaces.Discrete(len(self.ports) + 1), + "dest_node_id": spaces.Discrete(len(self.nodes) + 1), + "dest_port": spaces.Discrete(len(self.ports) + 1), + "protocol": spaces.Discrete(len(self.protocols) + 1), + } + ) + for i in range(self.num_rules) + } + ) + } + ) + + +class ICSObservation(AbstractObservation): + def observe(self, state: Dict) -> Any: + return 0 + + @property + def space(self) -> spaces.Space: + return spaces.Discrete(1) class ObservationSpace: - """Manage the observations of an Actor.""" + """ + Manage the observations of an Actor. + + The observation space has the purpose of: + 1. Reading the outputted state from the PrimAITE Simulation. + 2. Selecting parts of the simulation state that are requested by the simulation config + 3. Formatting this information so an actor can use it to make decisions. + """ ... + # what this class does: # keep a list of observations # create observations for an actor from the config + def __init__( + self, + simulation: Simulation, + nodes: List[NodeObservation] = [], + links: List[LinkObservation] = [], + acl: Optional[AclObservation] = None, + ics: Optional[ICSObservation] = None, + ) -> None: + self.simulation: Simulation = simulation + self.parts: Dict[str, AbstractObservation] = {} + self.nodes: List[NodeObservation] = nodes + self.links: List[LinkObservation] = links + self.acl: Optional[AclObservation] = acl + self.ics: Optional[ICSObservation] = ics -# Example YAML file for agent observation space -""" -arcd_gate: - rl_framework: SB3 - rl_algo: PPO - n_learn_steps: 128 - n_learn_episodes: 1000 + def observe(self) -> None: + ... -game_layer: - agents: - - ref: client_1_green_user - type: GREEN - node_ref: client_1 - service: WebBrowser - pol: - - step: 1 - action: START + @property + def space(self) -> None: + ... - - ref: client_1_data_manip_red_bot - node_ref: client_1 - service: DataManipulationBot - execution_definition: - - server_ip_address: 192.168.1.10 - - server_password: - - payload: 'ATTACK' - - pol: - - step: 75 - action: EXECUTE - - - - -simulation: - nodes: - - ref: client_1 - hostname: client_1 - node_type: Computer - ip_address: 192.168.10.100 - services: - - name: DataManipulationBot - links: - endpoint_a: - endpoint_b: 1524552-fgfg4147gdh-25gh4gd -rewards: - -""" + @classmethod + def from_config(self) -> None: + ... diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index a14d6a6d..f5ba8444 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -855,14 +855,14 @@ class ICMP: class NodeOperatingState(Enum): """Enumeration of Node Operating States.""" - OFF = 0 - "The node is powered off." ON = 1 "The node is powered on." - SHUTTING_DOWN = 2 - "The node is in the process of shutting down." + OFF = 2 + "The node is powered off." BOOTING = 3 "The node is in the process of booting up." + SHUTTING_DOWN = 4 + "The node is in the process of shutting down." class Node(SimComponent): diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 53b9b176..7870caab 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -58,7 +58,14 @@ class ACLRule(SimComponent): :return: A dictionary representing the current state. """ - pass + state = super().describe_state() + state["action"] = self.action.value + state["protocol"] = self.protocol.value + state["src_ip_address"] = self.src_ip_address + state["src_port"] = self.src_port.value + state["dst_ip_address"] = self.dst_ip_address + state["dst_port"] = self.dst_port.value + return state class AccessControlList(SimComponent): @@ -123,7 +130,12 @@ class AccessControlList(SimComponent): :return: A dictionary representing the current state. """ - pass + state = super().describe_state() + state["implicit_action"] = self.implicit_action.value + state["implicit_rule"] = self.implicit_rule.describe_state() + state["max_acl_rules"] = self.max_acl_rules + state["acl"] = {i: r.describe_state() if isinstance(r, ACLRule) else None for i, r in enumerate(self._acl)} + return state @property def acl(self) -> List[Optional[ACLRule]]: @@ -648,7 +660,10 @@ class Router(Node): :return: A dictionary representing the current state. """ - pass + state = super().describe_state() + state["num_ports"] = (self.num_ports,) + state["acl"] = (self.acl.describe_state(),) + return state def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: """ diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 20b92027..fb12fc3d 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -15,14 +15,14 @@ class ServiceOperatingState(Enum): "The service is currently running." STOPPED = 2 "The service is not running." - INSTALLING = 3 - "The service is being installed or updated." - RESTARTING = 4 - "The service is in the process of restarting." - PAUSED = 5 + PAUSED = 3 "The service is temporarily paused." - DISABLED = 6 + DISABLED = 4 "The service is disabled and cannot be started." + INSTALLING = 5 + "The service is being installed or updated." + RESTARTING = 6 + "The service is in the process of restarting." class Service(IOSoftware): @@ -60,7 +60,7 @@ class Service(IOSoftware): :rtype: Dict """ state = super().describe_state() - state.update({"operating_state": self.operating_state.name}) + state.update({"operating_state": self.operating_state.value}) return state def reset_component_for_episode(self, episode: int): diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py new file mode 100644 index 00000000..4f778f78 --- /dev/null +++ b/tests/integration_tests/game_layer/test_observations.py @@ -0,0 +1,20 @@ +from gym import spaces + +from primaite.game.actor.observations import FileObservation +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.sim_container import Simulation + + +def test_file_observation(): + sim = Simulation() + pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0") + sim.network.add_node(pc) + f = pc.file_system.create_file(file_name="dog.png") + + state = sim.describe_state() + + dog_file_obs = FileObservation( + where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"] + ) + assert dog_file_obs(state) == {"health_status": 1} + assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})