diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index c8095aa5..35468098 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -43,7 +43,7 @@ class AbstractAction(ABC): """Dictionary describing the number of options for each parameter of this action. The keys of this dict must align with the keyword args of the form_request method.""" self.manager: ActionManager = manager - """Reference to the ActionManager which created this action. This is used to access the session and simulation + """Reference to the ActionManager which created this action. This is used to access the game and simulation objects.""" @abstractmethod @@ -559,7 +559,7 @@ class ActionManager: def __init__( self, - session: "PrimaiteGame", # reference to session for looking up stuff + game: "PrimaiteGame", # reference to game for information lookup actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node max_folders_per_node: int = 2, # allows calculating shape @@ -574,8 +574,8 @@ class ActionManager: ) -> None: """Init method for ActionManager. - :param session: Reference to the session to which the agent belongs. - :type session: PrimaiteSession + :param game: Reference to the game to which the agent belongs. + :type game: PrimaiteGame :param actions: List of action types which should be made available to the agent. :type actions: List[str] :param node_uuids: List of node UUIDs that this agent can act on. @@ -599,8 +599,8 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.session: "PrimaiteGame" = session - self.sim: Simulation = self.session.simulation + self.game: "PrimaiteGame" = game + self.sim: Simulation = self.game.simulation self.node_uuids: List[str] = node_uuids self.protocols: List[str] = protocols self.ports: List[str] = ports @@ -826,7 +826,7 @@ class ActionManager: return nics[nic_idx] @classmethod - def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager": + def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": """ Construct an ActionManager from a config definition. @@ -845,20 +845,20 @@ class ActionManager: These options are used to calculate the shape of the action space, and to provide additional information to the ActionManager which is required to convert the agent's action choice into a CAOS request. - :param session: The Primaite Session to which the agent belongs. - :type session: PrimaiteSession + :param game: The Primaite Game to which the agent belongs. + :type game: PrimaiteGame :param cfg: The action space config. :type cfg: Dict :return: The constructed ActionManager. :rtype: ActionManager """ obj = cls( - session=session, + game=game, actions=cfg["action_list"], # node_uuids=cfg["options"]["node_uuids"], **cfg["options"], - protocols=session.options.protocols, - ports=session.options.ports, + protocols=game.options.protocols, + ports=game.options.ports, ip_address_list=None, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index f57ec10d..14fb2fa7 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -37,10 +37,10 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config: Dict, session: "PrimaiteGame"): + def from_config(cls, config: Dict, game: "PrimaiteGame"): """Create this observation space component form a serialised format. - The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, + The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation, a subclass of this class may need to translate from a 'reference' to a UUID. """ pass @@ -91,13 +91,13 @@ class FileObservation(AbstractObservation): return spaces.Dict({"health_status": spaces.Discrete(6)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": """Create file observation from a config. :param config: Dictionary containing the configuration for this file observation. :type config: Dict - :param session: _description_ - :type session: PrimaiteSession + :param game: _description_ + :type game: PrimaiteGame :param parent_where: _description_, defaults to None :type parent_where: _type_, optional :return: _description_ @@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None ) -> "ServiceObservation": """Create service observation from a config. :param config: Dictionary containing the configuration for this service observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. :type parent_where: Optional[List[str]], optional :return: Constructed service observation :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid]) + return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid]) class LinkObservation(AbstractObservation): @@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation): return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": """Create link observation from a config. :param config: Dictionary containing the configuration for this link observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Constructed link observation :rtype: LinkObservation """ - return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]]) + return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) class FolderObservation(AbstractObservation): @@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 ) -> "FolderObservation": """Create folder observation from a config. Also creates child file observations. :param config: Dictionary containing the configuration for this folder observation. Includes the name of the folder and the files inside of it. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this folder's parent node. A typical location for a node ``where`` can be: ['network','nodes',,'file_system'] @@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation): where = parent_where + ["folders", config["folder_name"]] file_configs = config["files"] - files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) @@ -376,13 +376,13 @@ class NicObservation(AbstractObservation): return spaces.Dict({"nic_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": """Create NIC observation from a config. :param config: Dictionary containing the configuration for this NIC observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent node. A typical location for a node ``where`` can be: ['network','nodes',] :type parent_where: Optional[List[str]] @@ -513,7 +513,7 @@ class NodeObservation(AbstractObservation): def from_config( cls, config: Dict, - session: "PrimaiteGame", + game: "PrimaiteGame", parent_where: Optional[List[str]] = None, num_services_per_node: int = 2, num_folders_per_node: int = 2, @@ -524,8 +524,8 @@ class NodeObservation(AbstractObservation): :param config: Dictionary containing the configuration for this node observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this node's parent network. A typical location for it would be: ['network',] :type parent_where: Optional[List[str]] @@ -541,24 +541,24 @@ class NodeObservation(AbstractObservation): :return: Constructed node observation :rtype: NodeObservation """ - node_uuid = session.ref_map_nodes[config["node_ref"]] + node_uuid = game.ref_map_nodes[config["node_ref"]] if parent_where is None: where = ["network", "nodes", node_uuid] else: where = parent_where + ["nodes", node_uuid] svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs] + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] folder_configs = config.get("folders", {}) folders = [ FolderObservation.from_config( - config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder + config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder ) for c in folder_configs ] - nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys() + nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys() nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else [] - nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs] + nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] logon_status = config.get("logon_status", False) return cls( where=where, @@ -692,13 +692,13 @@ class AclObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": """Generate ACL observation from a config. :param config: Dictionary containing the configuration for this ACL observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Observation object :rtype: AclObservation """ @@ -707,15 +707,15 @@ class AclObservation(AbstractObservation): for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): node_ref = ip_map_config["node_ref"] nic_num = ip_map_config["nic_num"] - node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]] + node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] nic_obj = node_obj.ethernet_port[nic_num] node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - router_uuid = session.ref_map_nodes[config["router_node_ref"]] + router_uuid = game.ref_map_nodes[config["router_node_ref"]] return cls( node_ip_to_id=node_ip_to_idx, - ports=session.options.ports, - protocols=session.options.protocols, + ports=game.options.ports, + protocols=game.options.protocols, where=["network", "nodes", router_uuid, "acl", "acl"], num_rules=max_acl_rules, ) @@ -738,7 +738,7 @@ class NullObservation(AbstractObservation): return spaces.Discrete(1) @classmethod - def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation": + def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": """ Create null observation from a config. @@ -834,14 +834,14 @@ class UC2BlueObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": """Create UC2 blue observation from a config. :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, links, ACL and ICS observations. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Constructed UC2 blue observation :rtype: UC2BlueObservation """ @@ -853,7 +853,7 @@ class UC2BlueObservation(AbstractObservation): nodes = [ NodeObservation.from_config( config=n, - session=session, + game=game, num_services_per_node=num_services_per_node, num_folders_per_node=num_folders_per_node, num_files_per_folder=num_files_per_folder, @@ -863,13 +863,13 @@ class UC2BlueObservation(AbstractObservation): ] link_configs = config["links"] - links = [LinkObservation.from_config(config=link, session=session) for link in link_configs] + links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, session=session) + acl = AclObservation.from_config(config=acl_config, game=game) ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, session=session) + ics = ICSObservation.from_config(config=ics_config, game=game) new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) return new @@ -905,17 +905,17 @@ class UC2RedObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": """ Create UC2 red observation from a config. :param config: Dictionary containing the configuration for this UC2 red observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame """ node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs] + nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] return cls(nodes=nodes, where=["network"]) @@ -964,7 +964,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. @@ -972,14 +972,14 @@ class ObservationManager: UC2BlueObservation, UC2RedObservation, UC2GreenObservation) The other key is 'options' which are passed to the constructor of the selected observation class. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame """ if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) else: raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 60c3678c..8a1c2da4 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,13 +47,13 @@ class AbstractReward: @classmethod @abstractmethod - def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward": + def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: AbstractReward """ @@ -68,13 +68,13 @@ class DummyReward(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward": + def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor. Should be empty. :type config: dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame """ return cls() @@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward): return 0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: DatabaseFileIntegrity """ @@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward): f"{cls.__name__} could not be initialised from config because file_name parameter was not specified" ) return DummyReward() # TODO: better error handling - node_uuid = session.ref_map_nodes[node_ref] + node_uuid = game.ref_map_nodes[node_ref] if not node_uuid: _LOGGER.error( ( @@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: WebServer404Penalty """ @@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward): ) _LOGGER.warn(msg) return DummyReward() # TODO: should we error out with incorrect inputs? Probably! - node_uuid = session.ref_map_nodes[node_ref] - service_uuid = session.ref_map_services[service_ref].uuid + node_uuid = game.ref_map_nodes[node_ref] + service_uuid = game.ref_map_services[service_ref].uuid if not (node_uuid and service_uuid): msg = ( f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" @@ -265,13 +265,13 @@ class RewardFunction: return self.current_reward @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction": """Create a reward function from a config dictionary. :param config: dict of options for the reward manager's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward manager. :rtype: RewardFunction """ @@ -281,6 +281,6 @@ class RewardFunction: rew_type = rew_component_cfg["type"] weight = rew_component_cfg.get("weight", 1.0) rew_class = cls.__rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session) + rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game) new.regsiter_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e260285f..fa17b94b 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,4 +1,4 @@ -"""PrimAITE session - the main entry point to training agents on PrimAITE.""" +"""PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address from typing import Dict, List @@ -52,7 +52,7 @@ class PrimaiteGame: """ def __init__(self): - """Initialise a PrimaiteSession object.""" + """Initialise a PrimaiteGame object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" @@ -101,7 +101,7 @@ class PrimaiteGame: single-agent gym, make sure to update the ProxyAgent's action with the action before calling ``self.apply_agent_actions()``. """ - _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") + _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") # Get the current state of the simulation sim_state = self.get_sim_state() @@ -149,14 +149,14 @@ class PrimaiteGame: return False def reset(self) -> None: - """Reset the session, this will reset the simulation.""" + """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 - _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") + _LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}") self.simulation.reset_component_for_episode(self.episode_counter) def close(self) -> None: - """Close the session, this will stop the env and close the simulation.""" + """Close the game, this will close the simulation.""" return NotImplemented @classmethod @@ -165,7 +165,7 @@ class PrimaiteGame: The config dictionary should have the following top-level keys: 1. training_config: options for training the RL agent. - 2. game_config: options for the game itself. Used by PrimaiteSession. + 2. game_config: options for the game itself. Used by PrimaiteGame. 3. simulation: defines the network topology and the initial state of the simulation. The specification for each of the three major areas is described in a separate documentation page. @@ -173,8 +173,8 @@ class PrimaiteGame: :param cfg: The config dictionary. :type cfg: dict - :return: A PrimaiteSession object. - :rtype: PrimaiteSession + :return: A PrimaiteGame object. + :rtype: PrimaiteGame """ game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) @@ -339,7 +339,7 @@ class PrimaiteGame: action_space = ActionManager.from_config(game, action_space_cfg) # CREATE REWARD FUNCTION - rew_function = RewardFunction.from_config(reward_function_cfg, session=game) + rew_function = RewardFunction.from_config(reward_function_cfg, game=game) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent":