Complete session->game rename refactor
This commit is contained in:
@@ -43,7 +43,7 @@ class AbstractAction(ABC):
|
||||
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
|
||||
align with the keyword args of the form_request method."""
|
||||
self.manager: ActionManager = manager
|
||||
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
|
||||
"""Reference to the ActionManager which created this action. This is used to access the game and simulation
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -559,7 +559,7 @@ class ActionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: "PrimaiteGame", # reference to session for looking up stuff
|
||||
game: "PrimaiteGame", # reference to game for information lookup
|
||||
actions: List[str], # stores list of actions available to agent
|
||||
node_uuids: List[str], # allows mapping index to node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
@@ -574,8 +574,8 @@ class ActionManager:
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param session: Reference to the session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param actions: List of action types which should be made available to the agent.
|
||||
:type actions: List[str]
|
||||
:param node_uuids: List of node UUIDs that this agent can act on.
|
||||
@@ -599,8 +599,8 @@ class ActionManager:
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.session: "PrimaiteGame" = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.sim: Simulation = self.game.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
@@ -826,7 +826,7 @@ class ActionManager:
|
||||
return nics[nic_idx]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
@@ -845,20 +845,20 @@ class ActionManager:
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param session: The Primaite Session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param game: The Primaite Game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
obj = cls(
|
||||
session=session,
|
||||
game=game,
|
||||
actions=cfg["action_list"],
|
||||
# node_uuids=cfg["options"]["node_uuids"],
|
||||
**cfg["options"],
|
||||
protocols=session.options.protocols,
|
||||
ports=session.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=None,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
@@ -37,10 +37,10 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame"):
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
|
||||
a subclass of this class may need to translate from a 'reference' to a UUID.
|
||||
"""
|
||||
pass
|
||||
@@ -91,13 +91,13 @@ class FileObservation(AbstractObservation):
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param session: _description_
|
||||
:type session: PrimaiteSession
|
||||
:param game: _description_
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
@@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
|
||||
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
@@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation):
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
@@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_uuid>,'file_system']
|
||||
@@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation):
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
|
||||
@@ -376,13 +376,13 @@ class NicObservation(AbstractObservation):
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
@@ -513,7 +513,7 @@ class NodeObservation(AbstractObservation):
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
session: "PrimaiteGame",
|
||||
game: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
@@ -524,8 +524,8 @@ class NodeObservation(AbstractObservation):
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
@@ -541,24 +541,24 @@ class NodeObservation(AbstractObservation):
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
"""
|
||||
node_uuid = session.ref_map_nodes[config["node_ref"]]
|
||||
node_uuid = game.ref_map_nodes[config["node_ref"]]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_uuid]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_uuid]
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
|
||||
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
|
||||
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
@@ -692,13 +692,13 @@ class AclObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
@@ -707,15 +707,15 @@ class AclObservation(AbstractObservation):
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_ref"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.ethernet_port[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
|
||||
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=session.options.ports,
|
||||
protocols=session.options.protocols,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
@@ -738,7 +738,7 @@ class NullObservation(AbstractObservation):
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
@@ -834,14 +834,14 @@ class UC2BlueObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
@@ -853,7 +853,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
session=session,
|
||||
game=game,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
@@ -863,13 +863,13 @@ class UC2BlueObservation(AbstractObservation):
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
|
||||
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, session=session)
|
||||
acl = AclObservation.from_config(config=acl_config, game=game)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, session=session)
|
||||
ics = ICSObservation.from_config(config=ics_config, game=game)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
@@ -905,17 +905,17 @@ class UC2RedObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
|
||||
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
@@ -964,7 +964,7 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
@@ -972,14 +972,14 @@ class ObservationManager:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
|
||||
@@ -47,13 +47,13 @@ class AbstractReward:
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward":
|
||||
def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward component.
|
||||
:rtype: AbstractReward
|
||||
"""
|
||||
@@ -68,13 +68,13 @@ class DummyReward(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward":
|
||||
def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor. Should be empty.
|
||||
:type config: dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
return cls()
|
||||
|
||||
@@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward component.
|
||||
:rtype: DatabaseFileIntegrity
|
||||
"""
|
||||
@@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
if not node_uuid:
|
||||
_LOGGER.error(
|
||||
(
|
||||
@@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward component.
|
||||
:rtype: WebServer404Penalty
|
||||
"""
|
||||
@@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward):
|
||||
)
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
service_uuid = session.ref_map_services[service_ref].uuid
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
service_uuid = game.ref_map_services[service_ref].uuid
|
||||
if not (node_uuid and service_uuid):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
|
||||
@@ -265,13 +265,13 @@ class RewardFunction:
|
||||
return self.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction":
|
||||
"""Create a reward function from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward manager's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward manager.
|
||||
:rtype: RewardFunction
|
||||
"""
|
||||
@@ -281,6 +281,6 @@ class RewardFunction:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
weight = rew_component_cfg.get("weight", 1.0)
|
||||
rew_class = cls.__rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game)
|
||||
new.regsiter_component(component=rew_instance, weight=weight)
|
||||
return new
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -52,7 +52,7 @@ class PrimaiteGame:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise a PrimaiteSession object."""
|
||||
"""Initialise a PrimaiteGame object."""
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
@@ -101,7 +101,7 @@ class PrimaiteGame:
|
||||
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
|
||||
``self.apply_agent_actions()``.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
@@ -149,14 +149,14 @@ class PrimaiteGame:
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
"""Reset the game, this will reset the simulation."""
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
|
||||
_LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(self.episode_counter)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
"""Close the game, this will close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
@@ -165,7 +165,7 @@ class PrimaiteGame:
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
2. game_config: options for the game itself. Used by PrimaiteGame.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
The specification for each of the three major areas is described in a separate documentation page.
|
||||
@@ -173,8 +173,8 @@ class PrimaiteGame:
|
||||
|
||||
:param cfg: The config dictionary.
|
||||
:type cfg: dict
|
||||
:return: A PrimaiteSession object.
|
||||
:rtype: PrimaiteSession
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
@@ -339,7 +339,7 @@ class PrimaiteGame:
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=game)
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
|
||||
Reference in New Issue
Block a user