Complete session->game rename refactor

This commit is contained in:
Marek Wolan
2023-11-24 09:14:55 +00:00
parent dd63563ba1
commit bd109a7cfc
4 changed files with 94 additions and 94 deletions

View File

@@ -43,7 +43,7 @@ class AbstractAction(ABC):
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
align with the keyword args of the form_request method."""
self.manager: ActionManager = manager
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
"""Reference to the ActionManager which created this action. This is used to access the game and simulation
objects."""
@abstractmethod
@@ -559,7 +559,7 @@ class ActionManager:
def __init__(
self,
session: "PrimaiteGame", # reference to session for looking up stuff
game: "PrimaiteGame", # reference to game for information lookup
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
max_folders_per_node: int = 2, # allows calculating shape
@@ -574,8 +574,8 @@ class ActionManager:
) -> None:
"""Init method for ActionManager.
:param session: Reference to the session to which the agent belongs.
:type session: PrimaiteSession
:param game: Reference to the game to which the agent belongs.
:type game: PrimaiteGame
:param actions: List of action types which should be made available to the agent.
:type actions: List[str]
:param node_uuids: List of node UUIDs that this agent can act on.
@@ -599,8 +599,8 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.session: "PrimaiteGame" = session
self.sim: Simulation = self.session.simulation
self.game: "PrimaiteGame" = game
self.sim: Simulation = self.game.simulation
self.node_uuids: List[str] = node_uuids
self.protocols: List[str] = protocols
self.ports: List[str] = ports
@@ -826,7 +826,7 @@ class ActionManager:
return nics[nic_idx]
@classmethod
def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager":
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config definition.
@@ -845,20 +845,20 @@ class ActionManager:
These options are used to calculate the shape of the action space, and to provide additional information
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
:param session: The Primaite Session to which the agent belongs.
:type session: PrimaiteSession
:param game: The Primaite Game to which the agent belongs.
:type game: PrimaiteGame
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
obj = cls(
session=session,
game=game,
actions=cfg["action_list"],
# node_uuids=cfg["options"]["node_uuids"],
**cfg["options"],
protocols=session.options.protocols,
ports=session.options.ports,
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=None,
act_map=cfg.get("action_map"),
)

View File

@@ -37,10 +37,10 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: Dict, session: "PrimaiteGame"):
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
"""
pass
@@ -91,13 +91,13 @@ class FileObservation(AbstractObservation):
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param session: _description_
:type session: PrimaiteSession
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
@@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
class LinkObservation(AbstractObservation):
@@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation):
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
@@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
@@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation):
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
@@ -376,13 +376,13 @@ class NicObservation(AbstractObservation):
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
:type parent_where: Optional[List[str]]
@@ -513,7 +513,7 @@ class NodeObservation(AbstractObservation):
def from_config(
cls,
config: Dict,
session: "PrimaiteGame",
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
@@ -524,8 +524,8 @@ class NodeObservation(AbstractObservation):
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
@@ -541,24 +541,24 @@ class NodeObservation(AbstractObservation):
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = session.ref_map_nodes[config["node_ref"]]
node_uuid = game.ref_map_nodes[config["node_ref"]]
if parent_where is None:
where = ["network", "nodes", node_uuid]
else:
where = parent_where + ["nodes", node_uuid]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
@@ -692,13 +692,13 @@ class AclObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
@@ -707,15 +707,15 @@ class AclObservation(AbstractObservation):
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_ref"]
nic_num = ip_map_config["nic_num"]
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -738,7 +738,7 @@ class NullObservation(AbstractObservation):
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation":
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
@@ -834,14 +834,14 @@ class UC2BlueObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
@@ -853,7 +853,7 @@ class UC2BlueObservation(AbstractObservation):
nodes = [
NodeObservation.from_config(
config=n,
session=session,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
@@ -863,13 +863,13 @@ class UC2BlueObservation(AbstractObservation):
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, session=session)
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, session=session)
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
@@ -905,17 +905,17 @@ class UC2RedObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
@@ -964,7 +964,7 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
@@ -972,14 +972,14 @@ class ObservationManager:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")

View File

@@ -47,13 +47,13 @@ class AbstractReward:
@classmethod
@abstractmethod
def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward":
def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: AbstractReward
"""
@@ -68,13 +68,13 @@ class DummyReward(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward":
def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
"""
return cls()
@@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward):
return 0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
@@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward):
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = session.ref_map_nodes[node_ref]
node_uuid = game.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
@@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: WebServer404Penalty
"""
@@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward):
)
_LOGGER.warn(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = session.ref_map_nodes[node_ref]
service_uuid = session.ref_map_services[service_ref].uuid
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref].uuid
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
@@ -265,13 +265,13 @@ class RewardFunction:
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction":
"""Create a reward function from a config dictionary.
:param config: dict of options for the reward manager's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward manager.
:rtype: RewardFunction
"""
@@ -281,6 +281,6 @@ class RewardFunction:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.__rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game)
new.regsiter_component(component=rew_instance, weight=weight)
return new

View File

@@ -1,4 +1,4 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List
@@ -52,7 +52,7 @@ class PrimaiteGame:
"""
def __init__(self):
"""Initialise a PrimaiteSession object."""
"""Initialise a PrimaiteGame object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
@@ -101,7 +101,7 @@ class PrimaiteGame:
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
@@ -149,14 +149,14 @@ class PrimaiteGame:
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
"""Reset the game, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
_LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
"""Close the game, this will close the simulation."""
return NotImplemented
@classmethod
@@ -165,7 +165,7 @@ class PrimaiteGame:
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
2. game_config: options for the game itself. Used by PrimaiteGame.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
@@ -173,8 +173,8 @@ class PrimaiteGame:
:param cfg: The config dictionary.
:type cfg: dict
:return: A PrimaiteSession object.
:rtype: PrimaiteSession
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
@@ -339,7 +339,7 @@ class PrimaiteGame:
action_space = ActionManager.from_config(game, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=game)
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":