diff --git a/example_config.yaml b/example_config.yaml index 00beaa1e..bcc819ae 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -2,10 +2,10 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_episodes: 4 - n_learn_steps: 128 - n_eval_episodes: 1 - n_eval_steps: 128 + n_learn_episodes: 1 + n_learn_steps: 8 + n_eval_episodes: 0 + n_eval_steps: 8 game_config: @@ -39,10 +39,10 @@ game_config: options: nodes: - node_ref: client_2 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 - max_nics_per_node: 8 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 max_acl_rules: 10 reward_function: @@ -93,9 +93,9 @@ game_config: options: nodes: - node_ref: client_1 - max_folders_per_node: 2 - max_files_per_folder: 2 - max_services_per_node: 2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 reward_function: reward_components: @@ -113,9 +113,10 @@ game_config: observation_space: type: UC2BlueObservation options: - num_services_per_node: 2 - num_folders_per_node: 2 - num_files_per_folder: 2 + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 nodes: - node_ref: domain_controller services: @@ -148,6 +149,8 @@ game_config: - link_ref: switch_2___client_2 - link_ref: switch_2___security_suite acl: + options: + max_acl_rules: 10 router_node_ref: router_1 ip_address_order: - node_ref: domain_controller diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 35fe8ac5..8eb322bd 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -167,7 +167,7 @@ class ServiceObservation(AbstractObservation): class LinkObservation(AbstractObservation): """Observation of a link in the network.""" - default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}} + default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} "Default observation is what should be returned when the link doesn't exist." def __init__(self, where: Optional[Tuple[str]] = None) -> None: @@ -206,7 +206,7 @@ class LinkObservation(AbstractObservation): utilisation_category = int(utilisation_fraction * 10) + 1 # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"protocols": {"all": {"load": utilisation_category}}} + return {"PROTOCOLS": {"ALL": utilisation_category}} @property def space(self) -> spaces.Space: @@ -215,7 +215,7 @@ class LinkObservation(AbstractObservation): :return: Gymnasium space :rtype: spaces.Space """ - return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})}) + return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation": @@ -264,7 +264,6 @@ class FolderObservation(AbstractObservation): truncated_file = self.files.pop() msg = f"Too many files in folde observation. Truncating file {truncated_file}" _LOGGER.warn(msg) - raise UserWarning(msg) self.default_observation = { "health_status": 0, @@ -407,6 +406,7 @@ class NodeObservation(AbstractObservation): num_services_per_node: int = 2, num_folders_per_node: int = 2, num_files_per_folder: int = 2, + num_nics_per_node: int = 2, ) -> None: """ Configurable observation for a node in the simulation. @@ -440,18 +440,25 @@ class NodeObservation(AbstractObservation): truncated_service = self.services.pop() msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" _LOGGER.warn(msg) - raise UserWarning(msg) # truncate service list self.folders: List[FolderObservation] = folders # add empty folder observation without `where` parameter that will always return default (blank) observations while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation()) + self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) while len(self.folders) > num_folders_per_node: truncated_folder = self.folders.pop() msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" + _LOGGER.warn(msg) self.nics: List[NicObservation] = nics + while len(self.nics) < num_nics_per_node: + self.nics.append(NicObservation()) + while len(self.nics) > num_nics_per_node: + truncated_nic = self.nics.pop() + msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" + _LOGGER.warn(msg) + self.logon_status: bool = logon_status self.default_observation: Dict = { @@ -512,6 +519,7 @@ class NodeObservation(AbstractObservation): num_services_per_node: int = 2, num_folders_per_node: int = 2, num_files_per_folder: int = 2, + num_nics_per_node: int = 2, ) -> "NodeObservation": """Create node observation from a config. Also creates child service, folder and NIC observations. @@ -562,6 +570,7 @@ class NodeObservation(AbstractObservation): num_services_per_node=num_services_per_node, num_folders_per_node=num_folders_per_node, num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, ) @@ -605,19 +614,17 @@ class AclObservation(AbstractObservation): self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} "List of protocols which are part of the game, defines ordering when converting to an ID" self.default_observation: Dict = { - "RULES": { - i - + 1: { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - for i in range(self.num_rules) + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, } + for i in range(self.num_rules) } def observe(self, state: Dict) -> Dict: @@ -636,10 +643,9 @@ class AclObservation(AbstractObservation): # TODO: what if the ACL has more rules than num of max rules for obs space obs = {} - obs["RULES"] = {} for i, rule_state in acl_state.items(): if rule_state is None: - obs["RULES"][i + 1] = { + obs[i + 1] = { "position": i, "permission": 0, "source_node_id": 0, @@ -649,7 +655,7 @@ class AclObservation(AbstractObservation): "protocol": 0, } else: - obs["RULES"][i + 1] = { + obs[i + 1] = { "position": i, "permission": rule_state["action"], "source_node_id": self.node_to_id[rule_state["src_ip_address"]], @@ -669,24 +675,20 @@ class AclObservation(AbstractObservation): """ return spaces.Dict( { - "RULES": spaces.Dict( + i + + 1: spaces.Dict( { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), } ) + for i in range(self.num_rules) } ) @@ -701,6 +703,7 @@ class AclObservation(AbstractObservation): :return: Observation object :rtype: AclObservation """ + max_acl_rules = config["options"]["max_acl_rules"] node_ip_to_idx = {} for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): node_ref = ip_map_config["node_ref"] @@ -715,6 +718,7 @@ class AclObservation(AbstractObservation): ports=session.options.ports, protocols=session.options.protocols, where=["network", "nodes", router_uuid, "acl", "acl"], + num_rules=max_acl_rules, ) @@ -846,6 +850,7 @@ class UC2BlueObservation(AbstractObservation): num_services_per_node = config["num_services_per_node"] num_folders_per_node = config["num_folders_per_node"] num_files_per_folder = config["num_files_per_folder"] + num_nics_per_node = config["num_nics_per_node"] nodes = [ NodeObservation.from_config( config=n, @@ -853,6 +858,7 @@ class UC2BlueObservation(AbstractObservation): num_services_per_node=num_services_per_node, num_folders_per_node=num_folders_per_node, num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, ) for n in node_configs ]