Merge branch 'dev' into feature/2041_2042-Add-NTP-Services
This commit is contained in:
File diff suppressed because it is too large
Load Diff
1164
src/primaite/config/_package_data/example_config_2_rl_agents.yaml
Normal file
1164
src/primaite/config/_package_data/example_config_2_rl_agents.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
@@ -43,7 +43,7 @@ class AbstractAction(ABC):
|
||||
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
|
||||
align with the keyword args of the form_request method."""
|
||||
self.manager: ActionManager = manager
|
||||
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
|
||||
"""Reference to the ActionManager which created this action. This is used to access the game and simulation
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -559,7 +559,7 @@ class ActionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: "PrimaiteSession", # reference to session for looking up stuff
|
||||
game: "PrimaiteGame", # reference to game for information lookup
|
||||
actions: List[str], # stores list of actions available to agent
|
||||
node_uuids: List[str], # allows mapping index to node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
@@ -574,8 +574,8 @@ class ActionManager:
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param session: Reference to the session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param actions: List of action types which should be made available to the agent.
|
||||
:type actions: List[str]
|
||||
:param node_uuids: List of node UUIDs that this agent can act on.
|
||||
@@ -599,8 +599,8 @@ class ActionManager:
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.sim: Simulation = self.game.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
@@ -826,7 +826,7 @@ class ActionManager:
|
||||
return nics[nic_idx]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
@@ -845,20 +845,20 @@ class ActionManager:
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param session: The Primaite Session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param game: The Primaite Game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
obj = cls(
|
||||
session=session,
|
||||
game=game,
|
||||
actions=cfg["action_list"],
|
||||
# node_uuids=cfg["options"]["node_uuids"],
|
||||
**cfg["options"],
|
||||
protocols=session.options.protocols,
|
||||
ports=session.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=None,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
@@ -37,10 +37,10 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
|
||||
a subclass of this class may need to translate from a 'reference' to a UUID.
|
||||
"""
|
||||
pass
|
||||
@@ -91,13 +91,13 @@ class FileObservation(AbstractObservation):
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param session: _description_
|
||||
:type session: PrimaiteSession
|
||||
:param game: _description_
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
@@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
|
||||
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
@@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation):
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
@@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_uuid>,'file_system']
|
||||
@@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation):
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
|
||||
@@ -376,15 +376,13 @@ class NicObservation(AbstractObservation):
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
|
||||
) -> "NicObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
@@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation):
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
session: "PrimaiteSession",
|
||||
game: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
@@ -526,8 +524,8 @@ class NodeObservation(AbstractObservation):
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
@@ -543,24 +541,24 @@ class NodeObservation(AbstractObservation):
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
"""
|
||||
node_uuid = session.ref_map_nodes[config["node_ref"]]
|
||||
node_uuid = game.ref_map_nodes[config["node_ref"]]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_uuid]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_uuid]
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
|
||||
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
|
||||
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
@@ -694,13 +692,13 @@ class AclObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
@@ -709,15 +707,15 @@ class AclObservation(AbstractObservation):
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_ref"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.ethernet_port[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
|
||||
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=session.options.ports,
|
||||
protocols=session.options.protocols,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
@@ -740,7 +738,7 @@ class NullObservation(AbstractObservation):
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
@@ -836,14 +834,14 @@ class UC2BlueObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
@@ -855,7 +853,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
session=session,
|
||||
game=game,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
@@ -865,13 +863,13 @@ class UC2BlueObservation(AbstractObservation):
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
|
||||
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, session=session)
|
||||
acl = AclObservation.from_config(config=acl_config, game=game)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, session=session)
|
||||
ics = ICSObservation.from_config(config=ics_config, game=game)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
@@ -907,17 +905,17 @@ class UC2RedObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
|
||||
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
@@ -966,7 +964,7 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
@@ -974,14 +972,14 @@ class ObservationManager:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
|
||||
@@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
@@ -47,13 +47,13 @@ class AbstractReward:
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
|
||||
def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward component.
|
||||
:rtype: AbstractReward
|
||||
"""
|
||||
@@ -68,13 +68,13 @@ class DummyReward(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
|
||||
def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor. Should be empty.
|
||||
:type config: dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
return cls()
|
||||
|
||||
@@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward component.
|
||||
:rtype: DatabaseFileIntegrity
|
||||
"""
|
||||
@@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
if not node_uuid:
|
||||
_LOGGER.error(
|
||||
(
|
||||
@@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward component.
|
||||
:rtype: WebServer404Penalty
|
||||
"""
|
||||
@@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward):
|
||||
)
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
service_uuid = session.ref_map_services[service_ref].uuid
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
service_uuid = game.ref_map_services[service_ref].uuid
|
||||
if not (node_uuid and service_uuid):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
|
||||
@@ -265,13 +265,13 @@ class RewardFunction:
|
||||
return self.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction":
|
||||
"""Create a reward function from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward manager's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:param game: Reference to the PrimAITE Game object
|
||||
:type game: PrimaiteGame
|
||||
:return: The reward manager.
|
||||
:rtype: RewardFunction
|
||||
"""
|
||||
@@ -281,6 +281,6 @@ class RewardFunction:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
weight = rew_component_cfg.get("weight", 1.0)
|
||||
rew_class = cls.__rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game)
|
||||
new.regsiter_component(component=rew_instance, weight=weight)
|
||||
return new
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from enum import Enum
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from copy import deepcopy
|
||||
from ipaddress import IPv4Address
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -13,8 +10,6 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.io import SessionIO, SessionIOSettings
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -36,65 +31,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
|
||||
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.session.apply_agent_actions()
|
||||
self.session.advance_timestep()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.session.calculate_truncated()
|
||||
info = {}
|
||||
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.session.reset()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
@property
|
||||
def action_space(self) -> gymnasium.Space:
|
||||
"""Return the action space of the environment."""
|
||||
return self.agent.action_manager.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
class PrimaiteGameOptions(BaseModel):
|
||||
"""
|
||||
Global options which are applicable to all of the agents in the game.
|
||||
|
||||
@@ -103,43 +40,26 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
max_episode_length: int = 256
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
class PrimaiteGame:
|
||||
"""
|
||||
Primaite game encapsulates the simulation and agents which interact with it.
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
|
||||
Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise a PrimaiteSession object."""
|
||||
"""Initialise a PrimaiteGame object."""
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self._simulation_initial_state = deepcopy(self.simulation)
|
||||
"""The Simulation original state (deepcopy of the original Simulation)."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
|
||||
@@ -152,15 +72,9 @@ class PrimaiteSession:
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.options: PrimaiteSessionOptions
|
||||
self.options: PrimaiteGameOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
|
||||
@@ -173,40 +87,6 @@ class PrimaiteSession:
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -225,7 +105,7 @@ class PrimaiteSession:
|
||||
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
|
||||
``self.apply_agent_actions()``.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
@@ -267,29 +147,29 @@ class PrimaiteSession:
|
||||
def calculate_truncated(self) -> bool:
|
||||
"""Calculate whether the episode is truncated."""
|
||||
current_step = self.step_counter
|
||||
max_steps = self.training_options.max_steps_per_episode
|
||||
max_steps = self.options.max_episode_length
|
||||
if current_step >= max_steps:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
"""Reset the game, this will reset the simulation."""
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(self.episode_counter)
|
||||
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation = deepcopy(self._simulation_initial_state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
"""Close the game, this will close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
|
||||
"""Create a PrimaiteGame object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
2. game_config: options for the game itself. Used by PrimaiteGame.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
The specification for each of the three major areas is described in a separate documentation page.
|
||||
@@ -297,26 +177,19 @@ class PrimaiteSession:
|
||||
|
||||
:param cfg: The config dictionary.
|
||||
:type cfg: dict
|
||||
:return: A PrimaiteSession object.
|
||||
:rtype: PrimaiteSession
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
sess = cls()
|
||||
sess.options = PrimaiteSessionOptions(
|
||||
ports=cfg["game_config"]["ports"],
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
sim = sess.simulation
|
||||
# 1. create simulation
|
||||
sim = game.simulation
|
||||
net = sim.network
|
||||
|
||||
sess.ref_map_nodes: Dict[str, Node] = {}
|
||||
sess.ref_map_services: Dict[str, Service] = {}
|
||||
sess.ref_map_links: Dict[str, Link] = {}
|
||||
game.ref_map_nodes: Dict[str, Node] = {}
|
||||
game.ref_map_services: Dict[str, Service] = {}
|
||||
game.ref_map_links: Dict[str, Link] = {}
|
||||
|
||||
nodes_cfg = cfg["simulation"]["network"]["nodes"]
|
||||
links_cfg = cfg["simulation"]["network"]["links"]
|
||||
@@ -380,7 +253,7 @@ class PrimaiteSession:
|
||||
print(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_types_mapping[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
sess.ref_map_services[service_ref] = new_service
|
||||
game.ref_map_services[service_ref] = new_service
|
||||
else:
|
||||
print(f"service type not found {service_type}")
|
||||
# service-dependent options
|
||||
@@ -405,7 +278,7 @@ class PrimaiteSession:
|
||||
if application_type in application_types_mapping:
|
||||
new_node.software_manager.install(application_types_mapping[application_type])
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
sess.ref_map_applications[application_ref] = new_application
|
||||
game.ref_map_applications[application_ref] = new_application
|
||||
else:
|
||||
print(f"application type not found {application_type}")
|
||||
if "nics" in node_cfg:
|
||||
@@ -414,16 +287,16 @@ class PrimaiteSession:
|
||||
|
||||
net.add_node(new_node)
|
||||
new_node.power_on()
|
||||
sess.ref_map_nodes[
|
||||
game.ref_map_nodes[
|
||||
node_ref
|
||||
] = (
|
||||
new_node.uuid
|
||||
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
|
||||
) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
if isinstance(node_a, Switch):
|
||||
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
|
||||
else:
|
||||
@@ -433,11 +306,10 @@ class PrimaiteSession:
|
||||
else:
|
||||
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
|
||||
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
|
||||
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
|
||||
# 3. create agents
|
||||
game_cfg = cfg["game_config"]
|
||||
agents_cfg = game_cfg["agents"]
|
||||
agents_cfg = cfg["agents"]
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_ref = agent_cfg["ref"] # noqa: F841
|
||||
@@ -447,14 +319,14 @@ class PrimaiteSession:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, game)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
# if a list of nodes is defined, convert them from node references to node UUIDs
|
||||
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
|
||||
if "node_ref" in action_node_option:
|
||||
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
|
||||
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
|
||||
action_space_cfg["options"]["node_uuids"].append(node_uuid)
|
||||
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
|
||||
# we will pass node_uuids as a part of the action space config.
|
||||
@@ -466,12 +338,12 @@ class PrimaiteSession:
|
||||
if "options" in action_config:
|
||||
if "target_router_ref" in action_config["options"]:
|
||||
_target = action_config["options"]["target_router_ref"]
|
||||
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
|
||||
action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target]
|
||||
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
@@ -482,7 +354,7 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
game.agents.append(new_agent)
|
||||
elif agent_type == "ProxyAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -490,8 +362,8 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agents.append(new_agent)
|
||||
game.agents.append(new_agent)
|
||||
game.rl_agents.append(new_agent)
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -499,16 +371,10 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
game.agents.append(new_agent)
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
|
||||
game._simulation_initial_state = deepcopy(game.simulation) # noqa
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
return game
|
||||
@@ -1,3 +0,0 @@
|
||||
from primaite.game.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy"]
|
||||
@@ -5,8 +5,8 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import load
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.config.load import example_config_path, load
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
@@ -42,6 +42,6 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.config:
|
||||
_LOGGER.error("Please provide a config file using the --config " "argument")
|
||||
args.config = example_config_path()
|
||||
|
||||
run(session_path=args.config)
|
||||
run(args.config)
|
||||
|
||||
127
src/primaite/notebooks/training_example_ray_multi_agent.ipynb
Normal file
127
src/primaite/notebooks/training_example_ray_multi_agent.ipynb
Normal file
@@ -0,0 +1,127 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"game = PrimaiteGame.from_config(cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# gym = PrimaiteRayEnv({\"game\":game})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ray\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ray.shutdown()\n",
|
||||
"ray.init()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"env_config = {\"game\":game}\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .multi_agent(\n",
|
||||
" policies={agent.agent_name for agent in game.rl_agents},\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"training_iteration\": 128},\n",
|
||||
" checkpoint_config=air.CheckpointConfig(\n",
|
||||
" checkpoint_frequency=10,\n",
|
||||
" ),\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
122
src/primaite/notebooks/training_example_ray_single_agent.ipynb
Normal file
122
src/primaite/notebooks/training_example_ray_single_agent.ipynb
Normal file
@@ -0,0 +1,122 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"game = PrimaiteGame.from_config(cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteRayEnv({\"game\":game})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms import ppo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ray.shutdown()\n",
|
||||
"ray.init()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env_config = {\"game\":game}\n",
|
||||
"config = {\n",
|
||||
" \"env\" : PrimaiteRayEnv,\n",
|
||||
" \"env_config\" : env_config,\n",
|
||||
" \"disable_env_checking\": True,\n",
|
||||
" \"num_rollout_workers\": 0,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"algo = ppo.PPO(config=config)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(5):\n",
|
||||
" result = algo.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"algo.save(\"temp/deleteme\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
102
src/primaite/notebooks/training_example_sb3.ipynb
Normal file
102
src/primaite/notebooks/training_example_sb3.ipynb
Normal file
@@ -0,0 +1,102 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.config.load import example_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"game = PrimaiteGame.from_config(cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(game=game)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3 import PPO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO('MlpPolicy', gym)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=1000)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
0
src/primaite/session/__init__.py
Normal file
0
src/primaite/session/__init__.py
Normal file
162
src/primaite/session/environment.py
Normal file
162
src/primaite/session/environment.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
|
||||
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.agent: ProxyAgent = self.game.rl_agents[0]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.game.reset()
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
@property
|
||||
def action_space(self) -> gymnasium.Space:
|
||||
"""Return the action space of the environment."""
|
||||
return self.agent.action_manager.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game=env_config["game"])
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Optional[Dict] = None) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.game: PrimaiteGame = env_config["game"]
|
||||
"""Reference to the primaite game"""
|
||||
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
|
||||
"""List of all possible agents in the environment. This list should not change!"""
|
||||
self._agent_ids = list(self.agents.keys())
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{name: agent.observation_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
self.game.reset()
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
def step(
|
||||
self, actions: Dict[str, ActType]
|
||||
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
|
||||
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
|
||||
|
||||
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
|
||||
:type actions: Dict[str, ActType]
|
||||
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
|
||||
# 3. Get next observations
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}
|
||||
4
src/primaite/session/policy/__init__.py
Normal file
4
src/primaite/session/policy/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from primaite.session.policy.rllib import RaySingleAgentPolicy
|
||||
from primaite.session.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
@@ -80,5 +80,3 @@ class PolicyABC(ABC):
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
|
||||
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
|
||||
106
src/primaite/session/policy/rllib.py
Normal file
106
src/primaite/session/policy/rllib.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
"""Single agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
super().__init__(session=session)
|
||||
|
||||
config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": {"game": session.game},
|
||||
"disable_env_checking": True,
|
||||
"num_rollout_workers": 0,
|
||||
}
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
self._algo = ppo.PPO(config=config)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
for ep in range(n_episodes):
|
||||
self._algo.train()
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
for ep in range(n_episodes):
|
||||
obs, info = self.session.env.reset()
|
||||
for step in range(self.session.game.options.max_episode_length):
|
||||
action = self._algo.compute_single_action(observation=obs, explore=False)
|
||||
obs, rew, term, trunc, info = self.session.env.step(action)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the policy to a file."""
|
||||
self._algo.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
|
||||
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
|
||||
"""Mutli agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
|
||||
"""Initialise multi agent policy wrapper."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self.config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in session.game.rl_agents},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": n_episodes * timesteps_per_episode},
|
||||
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
|
||||
),
|
||||
param_space=self.config,
|
||||
).fit()
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
return NotImplemented
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate trained policy."""
|
||||
return NotImplemented
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save policy parameters to a file."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
|
||||
"""Create policy from config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -8,10 +8,10 @@ from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
113
src/primaite/session/session.py
Normal file
113
src/primaite/session/session.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
|
||||
# from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
|
||||
"""The environment that the RL algorithm can consume."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game: PrimaiteGame = game
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary."""
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
sess = cls(game=game)
|
||||
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config={"game": game})
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game=game)
|
||||
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
@@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network:
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
|
||||
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
|
||||
# Client 2
|
||||
client_2 = Computer(
|
||||
|
||||
@@ -2,13 +2,14 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
"""
|
||||
@@ -148,21 +149,6 @@ class DatabaseClient(Application):
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
|
||||
def _print_data(self, data: Dict):
|
||||
"""
|
||||
Display the contents of the Folder in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
if data:
|
||||
table = PrettyTable(list(data.values())[0])
|
||||
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Database Client"
|
||||
for row in data.values():
|
||||
table.add_row(row.values())
|
||||
print(table)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
@@ -179,5 +165,5 @@ class DatabaseClient(Application):
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
self._print_data(payload["data"])
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
return True
|
||||
|
||||
@@ -100,8 +100,16 @@ class SoftwareManager:
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
for key, value in self.port_protocol_mapping.items():
|
||||
if value.name == software_name:
|
||||
self.port_protocol_mapping.pop(key)
|
||||
break
|
||||
for key, value in self._software_class_to_name_map.items():
|
||||
if value == software_name:
|
||||
self._software_class_to_name_map.pop(key)
|
||||
break
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
self.sys_log.info(f"Uninstalled {software_name}")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -19,7 +15,7 @@ class DatabaseService(Service):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
@@ -41,38 +37,6 @@ class DatabaseService(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
self._conn = sqlite3.connect(self._db_file.sim_path)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def tables(self) -> List[str]:
|
||||
"""
|
||||
Get a list of table names present in the database.
|
||||
|
||||
:return: List of table names.
|
||||
"""
|
||||
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
|
||||
results = self._process_sql(sql, None)
|
||||
if isinstance(results["data"], dict):
|
||||
return list(results["data"].keys())
|
||||
return []
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Prints a list of table names in the database using PrettyTable.
|
||||
|
||||
:param markdown: Whether to output the table in Markdown format.
|
||||
"""
|
||||
table = PrettyTable(["Table"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.file_system.sys_log.hostname} Database"
|
||||
for row in self.tables():
|
||||
table.add_row([row])
|
||||
print(table)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
@@ -89,8 +53,6 @@ class DatabaseService(Service):
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
self._conn.close()
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
|
||||
@@ -98,12 +60,10 @@ class DatabaseService(Service):
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self._db_file.folder.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
real_file_path=self._db_file.sim_path,
|
||||
)
|
||||
self._connect()
|
||||
|
||||
if response:
|
||||
return True
|
||||
@@ -125,25 +85,29 @@ class DatabaseService(Service):
|
||||
dest_ip_address=self.backup_server,
|
||||
)
|
||||
|
||||
if response:
|
||||
self._conn.close()
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self._connect()
|
||||
if not response:
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
|
||||
return self._db_file is not None
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
if self._db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
return True
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
@@ -163,31 +127,32 @@ class DatabaseService(Service):
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
|
||||
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
Possible queries:
|
||||
- SELECT : returns the data
|
||||
- DELETE : deletes the data
|
||||
|
||||
:param query: The SQL query to be executed.
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
self.sys_log.error(f"{self.name}: Error, query failed")
|
||||
return {"status_code": 404, "data": {}}
|
||||
data = []
|
||||
description = self._cursor.description
|
||||
if description:
|
||||
headers = []
|
||||
for header in description:
|
||||
headers.append(header[0])
|
||||
data = self._cursor.fetchall()
|
||||
if data and headers:
|
||||
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
|
||||
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,7 @@ class WebServer(Service):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
|
||||
# get all users
|
||||
if db_client.query("SELECT * FROM user;"):
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
response.status_code = HttpStatusCode.OK
|
||||
|
||||
|
||||
Reference in New Issue
Block a user