Resolve merge conflicts

Marek Wolan
2023-11-26 23:29:14 +00:00
parent cbdaa6c444
commit ece9b14d63
33 changed files with 6074 additions and 3068 deletions

View File

@@ -7,207 +7,28 @@
Run a PrimAITE Session
======================
``PrimaiteSession`` allows the user to train or evaluate an RL agent on the PrimAITE simulation with just a config file,
no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment,
policy, simulator, agents, and IO.
If you want finer control over the RL policy, you can interface with the :py:mod:`primaite.session.environment`
module directly without running a session.
Run
---
A PrimAITE session can be ran either with the ``primaite session`` command from the cli
A PrimAITE session can be started either with the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
.. note::
🚧 *UNDER CONSTRUCTION* 🚧
There are two parameters that can be specified:
- ``--config``: The path to the config file to use. If not specified, the default config file is used.
- ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created.
..
.. code-block:: bash
:caption: Unix CLI
Outputs
-------
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to training config yaml file>
lay_down_config = <path to lay down config yaml file>
run(training_config, lay_down_config)
When a session is run, a session output sub-directory is created in the user's app sessions directory (``~/primaite/2.0.0/sessions``).
The sub-directory is named as follows: ``~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>/``
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``.
``primaite session`` can be run in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``.
To run a PrimAITE session using legacy training or lay down config files, add the ``--legacy-tc`` and/or ``--legacy-ldc`` options.
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to legacy training config yaml file>
lay_down_config = <path to legacy lay down config yaml file>
run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True)
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents (training checkpoints and a final trained agent)
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date and time the session started, in ISO 8601 format.
* **end_datetime** - The date and time the session ended, in ISO 8601 format.
* **learning**
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **evaluation**
* **total_episodes** - The total number of evaluation episodes completed.
* **total_time_steps** - The total number of evaluation time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
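As a rough illustration of how the metadata listed above might be consumed, the sketch below parses a hand-written sample and reports the episode counts and the session duration. The sample values are invented, and the nesting is an assumption that mirrors the bullet list; a real ``session_metadata.json`` may contain more keys or a different layout.

```python
import json
from datetime import datetime

# Hypothetical sample; the values are invented and the nesting is assumed
# to mirror the bullet list above.
sample = """
{
  "uuid": "00000000-0000-0000-0000-000000000000",
  "start_datetime": "2023-01-31T17:30:00",
  "end_datetime": "2023-01-31T17:45:00",
  "learning": {"total_episodes": 10, "total_time_steps": 2560},
  "evaluation": {"total_episodes": 5, "total_time_steps": 1280}
}
"""

metadata = json.loads(sample)

# Report the per-phase counters.
for phase in ("learning", "evaluation"):
    stats = metadata[phase]
    print(f"{phase}: {stats['total_episodes']} episodes, {stats['total_time_steps']} steps")

# ISO-formatted timestamps can be parsed back into datetime objects.
start = datetime.fromisoformat(metadata["start_datetime"])
end = datetime.fromisoformat(metadata["end_datetime"])
duration_seconds = (end - start).total_seconds()
print(f"duration: {duration_seconds:.0f} s")
```

To read a real file, ``json.loads(sample)`` would be replaced by ``json.load`` on the opened metadata file.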
**Results**
PrimAITE automatically creates two sets of results from each learning and evaluation session:
* Average reward per episode - a csv file listing the average reward for each episode of the session. This gives, for example, an indication of how the reward value changes over a training session.
* All transactions - a csv file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
**Diagrams**
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using Plotly and saves it to the learning or evaluation subdirectory in the session directory.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
**Example Session Directory Structure**
.. code-block:: text
~/
└── primaite/
└── 2.0.0/
└── sessions/
└── 2023-07-18/
└── 2023-07-18_11-06-04/
├── evaluation/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
├── learning/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
│ ├── checkpoints/
│ │ └── sb3ppo_10.zip
│ ├── SB3_PPO.zip
│ └── tensorboard_logs/
│ ├── PPO_1/
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
│ ├── PPO_2/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
│ ├── PPO_3/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
│ ├── PPO_4/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
│ ├── PPO_5/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
│ ├── PPO_6/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
│ ├── PPO_7/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
│ ├── PPO_8/
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
│ ├── PPO_9/
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
│ └── PPO_10/
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
├── network_2023-07-18_11-06-04.png
└── session_metadata.json
Loading a session
-----------------
A previous session can be loaded by providing the **directory** of the previous session either to the ``primaite session`` command from the CLI
(see :func:`primaite.cli.session`) or to :func:`primaite.main.run` via ``session_path``.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, it writes its output to the provided session directory.
Running a session creates a session output directory in your user data folder. The filepath looks like this:
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
the saved agent checkpoints, and final model.
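As a minimal sketch of the directory layout described above, the following snippet builds the expected output path from a session start time. The helper name ``session_output_dir`` and the default ``3.0.0`` version string are assumptions made for illustration; they are not part of the PrimAITE API.

```python
from datetime import datetime
from pathlib import Path


def session_output_dir(start: datetime, version: str = "3.0.0") -> Path:
    """Build ~/primaite/<version>/sessions/YYYY-MM-DD/HH-MM-SS (hypothetical helper)."""
    base = Path.home() / "primaite" / version / "sessions"
    # One folder per day, then one folder per session start time.
    return base / start.strftime("%Y-%m-%d") / start.strftime("%H-%M-%S")


example = session_output_dir(datetime(2023, 11, 26, 23, 29, 14))
print(example)
```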

View File

@@ -39,6 +39,7 @@ dependencies = [
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.1.1",
"ray[rllib] == 2.8.0, < 3"
]
[tool.setuptools.dynamic]

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractAction(ABC):
@@ -43,7 +43,7 @@ class AbstractAction(ABC):
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
align with the keyword args of the form_request method."""
self.manager: ActionManager = manager
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
"""Reference to the ActionManager which created this action. This is used to access the game and simulation
objects."""
@abstractmethod
@@ -591,7 +591,7 @@ class ActionManager:
def __init__(
self,
session: "PrimaiteSession", # reference to session for looking up stuff
game: "PrimaiteGame", # reference to game for information lookup
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
application_uuids: List[List[str]], # allows mapping index to application
@@ -608,8 +608,8 @@ class ActionManager:
) -> None:
"""Init method for ActionManager.
:param session: Reference to the session to which the agent belongs.
:type session: PrimaiteSession
:param game: Reference to the game to which the agent belongs.
:type game: PrimaiteGame
:param actions: List of action types which should be made available to the agent.
:type actions: List[str]
:param node_uuids: List of node UUIDs that this agent can act on.
@@ -633,8 +633,8 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.session: "PrimaiteSession" = session
self.sim: Simulation = self.session.simulation
self.game: "PrimaiteGame" = game
self.sim: Simulation = self.game.simulation
self.node_uuids: List[str] = node_uuids
self.application_uuids: List[List[str]] = application_uuids
self.protocols: List[str] = protocols
@@ -874,7 +874,7 @@ class ActionManager:
return nics[nic_idx]
@classmethod
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config definition.
@@ -893,20 +893,20 @@ class ActionManager:
These options are used to calculate the shape of the action space, and to provide additional information
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
:param session: The Primaite Session to which the agent belongs.
:type session: PrimaiteSession
:param game: The Primaite Game to which the agent belongs.
:type game: PrimaiteGame
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
obj = cls(
session=session,
game=game,
actions=cfg["action_list"],
# node_uuids=cfg["options"]["node_uuids"],
**cfg["options"],
protocols=session.options.protocols,
ports=session.options.ports,
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=None,
act_map=cfg.get("action_map"),
)

View File

@@ -38,7 +38,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.session.step_counter
current_timestep = self.action_manager.game.step_counter
if current_timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
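The gating logic in this hunk can be sketched in isolation as follows. ``FakeGame``, ``WaitingAgent`` and the ``EXECUTE`` action name are stand-ins invented for this example; only the ``step_counter`` comparison and the ``DONOTHING`` return value come from the diff above.

```python
from typing import Dict, List, Tuple


class FakeGame:
    """Stand-in for PrimaiteGame; only the step counter is modelled."""

    def __init__(self) -> None:
        self.step_counter = 0


class WaitingAgent:
    """Hypothetical scripted agent that stays idle until a chosen timestep."""

    def __init__(self, game: FakeGame, next_execution_timestep: int) -> None:
        self.game = game
        self.next_execution_timestep = next_execution_timestep

    def get_action(self) -> Tuple[str, Dict]:
        # Same comparison as in the diff: do nothing until the timestep is reached.
        if self.game.step_counter < self.next_execution_timestep:
            return "DONOTHING", {"dummy": 0}
        return "EXECUTE", {}  # placeholder name, not a real PrimAITE action


game = FakeGame()
agent = WaitingAgent(game, next_execution_timestep=3)
actions: List[str] = []
for _ in range(5):
    actions.append(agent.get_action()[0])
    game.step_counter += 1
print(actions)
```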

View File

@@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractObservation(ABC):
@@ -37,10 +37,10 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component from a serialised format.
The `session` parameter is for the PrimaiteSession object that spawns this component. During deserialisation,
The `game` parameter is for the PrimaiteGame object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
"""
pass
@@ -91,13 +91,13 @@ class FileObservation(AbstractObservation):
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param session: _description_
:type session: PrimaiteSession
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
@@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
class LinkObservation(AbstractObservation):
@@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation):
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
@@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
@@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation):
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
@@ -376,15 +376,13 @@ class NicObservation(AbstractObservation):
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
) -> "NicObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
:type parent_where: Optional[List[str]]
@@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation):
def from_config(
cls,
config: Dict,
session: "PrimaiteSession",
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
@@ -526,8 +524,8 @@ class NodeObservation(AbstractObservation):
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
@@ -543,24 +541,24 @@ class NodeObservation(AbstractObservation):
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = session.ref_map_nodes[config["node_ref"]]
node_uuid = game.ref_map_nodes[config["node_ref"]]
if parent_where is None:
where = ["network", "nodes", node_uuid]
else:
where = parent_where + ["nodes", node_uuid]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
@@ -694,13 +692,13 @@ class AclObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
@@ -709,15 +707,15 @@ class AclObservation(AbstractObservation):
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_ref"]
nic_num = ip_map_config["nic_num"]
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -740,7 +738,7 @@ class NullObservation(AbstractObservation):
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
@@ -836,14 +834,14 @@ class UC2BlueObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
@@ -855,7 +853,7 @@ class UC2BlueObservation(AbstractObservation):
nodes = [
NodeObservation.from_config(
config=n,
session=session,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
@@ -865,13 +863,13 @@ class UC2BlueObservation(AbstractObservation):
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, session=session)
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, session=session)
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
@@ -907,17 +905,17 @@ class UC2RedObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
@@ -966,7 +964,7 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
@@ -974,14 +972,14 @@ class ObservationManager:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")
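The ``if``/``elif`` chain above selects an observation class from ``config["type"]``. One alternative way to express the same dispatch, shown here only as a sketch and not as the project's implementation, is a lookup table. The builder functions and the ``OBSERVATION_BUILDERS`` name are hypothetical stand-ins for the real ``from_config`` classmethods.

```python
from typing import Any, Callable, Dict


def build_blue(options: Dict[str, Any], game: Any) -> str:
    """Stand-in for UC2BlueObservation.from_config."""
    return f"blue observation ({len(options)} option(s))"


def build_red(options: Dict[str, Any], game: Any) -> str:
    """Stand-in for UC2RedObservation.from_config."""
    return f"red observation ({len(options)} option(s))"


# Maps the config's "type" string to the function that builds the observation.
OBSERVATION_BUILDERS: Dict[str, Callable[[Dict[str, Any], Any], str]] = {
    "UC2BlueObservation": build_blue,
    "UC2RedObservation": build_red,
}


def observation_from_config(config: Dict[str, Any], game: Any) -> str:
    obs_type = config["type"]
    if obs_type not in OBSERVATION_BUILDERS:
        raise ValueError("Observation space type invalid")
    # Missing "options" falls back to an empty dict, as in the diff above.
    return OBSERVATION_BUILDERS[obs_type](config.get("options", {}), game)


result = observation_from_config({"type": "UC2RedObservation"}, game=None)
print(result)
```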

View File

@@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractReward:
@@ -47,13 +47,13 @@ class AbstractReward:
@classmethod
@abstractmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: AbstractReward
"""
@@ -68,13 +68,13 @@ class DummyReward(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
"""
return cls()
@@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward):
return 0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
@@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward):
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = session.ref_map_nodes[node_ref]
node_uuid = game.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
@@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: WebServer404Penalty
"""
@@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward):
)
_LOGGER.warning(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = session.ref_map_nodes[node_ref]
service_uuid = session.ref_map_services[service_ref].uuid
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref].uuid
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
@@ -265,13 +265,13 @@ class RewardFunction:
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction":
"""Create a reward function from a config dictionary.
:param config: dict of options for the reward manager's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward manager.
:rtype: RewardFunction
"""
@@ -281,6 +281,6 @@ class RewardFunction:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.__rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game)
new.regsiter_component(component=rew_instance, weight=weight)
return new
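The loop above reads a ``type``, an optional ``weight`` (default ``1.0``) and optional ``options`` for each reward component. The sketch below shows one way such components could be combined, assuming the total reward is a weighted sum; the diff does not show how the reward is actually computed. The ``WeightedReward`` class, the ``reward_components`` key, the registry and both component types are hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

RewardComponent = Callable[[Dict[str, Any]], float]


class WeightedReward:
    """Hypothetical reward manager that sums weighted component rewards."""

    def __init__(self) -> None:
        self.components: List[Tuple[RewardComponent, float]] = []

    def register_component(self, component: RewardComponent, weight: float = 1.0) -> None:
        self.components.append((component, weight))

    def calculate(self, state: Dict[str, Any]) -> float:
        # Assumed combination rule: weighted sum over all registered components.
        return sum(weight * component(state) for component, weight in self.components)

    @classmethod
    def from_config(cls, config: Dict[str, Any], registry: Dict[str, Callable]) -> "WeightedReward":
        new = cls()
        for component_cfg in config["reward_components"]:  # assumed key name
            factory = registry[component_cfg["type"]]
            component = factory(component_cfg.get("options", {}))
            new.register_component(component, component_cfg.get("weight", 1.0))
        return new


# Two toy component factories; each takes an options dict and returns a callable.
registry = {
    "FILE_OK": lambda options: (lambda state: 1.0 if state["file_ok"] else -1.0),
    "GOT_404": lambda options: (lambda state: -1.0 if state["got_404"] else 0.0),
}
config = {
    "reward_components": [
        {"type": "FILE_OK", "weight": 0.5},
        {"type": "GOT_404", "weight": 0.25},
    ]
}
reward = WeightedReward.from_config(config, registry)
total = reward.calculate({"file_ok": True, "got_404": True})
print(total)
```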

View File

@@ -1,12 +1,8 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
"""PrimAITE game - Encapsulates the simulation and agents."""
from copy import deepcopy
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
from typing import Dict, List
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
@@ -15,8 +11,6 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -40,65 +34,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen by the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.session.apply_agent_actions()
self.session.advance_timestep()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.session.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
class PrimaiteGameOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
@@ -107,40 +43,20 @@ class PrimaiteSessionOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
max_episode_length: int = 256
ports: List[str]
protocols: List[str]
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
class PrimaiteGame:
"""
Primaite game encapsulates the simulation and agents which interact with it.
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment.
"""
def __init__(self):
"""Initialise a PrimaiteSession object."""
"""Initialise a PrimaiteGame object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
@@ -159,15 +75,9 @@ class PrimaiteSession:
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
@@ -180,40 +90,6 @@ class PrimaiteSession:
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -232,7 +108,7 @@ class PrimaiteSession:
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
@@ -274,29 +150,29 @@ class PrimaiteSession:
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
max_steps = self.options.max_episode_length
if current_step >= max_steps:
return True
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
"""Reset the game, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
"""Close the game, this will close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
2. game_config: options for the game itself. Used by PrimaiteGame.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
@@ -304,26 +180,19 @@ class PrimaiteSession:
:param cfg: The config dictionary.
:type cfg: dict
:return: A PrimaiteSession object.
:rtype: PrimaiteSession
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
sess = cls()
sess.options = PrimaiteSessionOptions(
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
# 1. create simulation
sim = game.simulation
net = sim.network
sess.ref_map_nodes: Dict[str, Node] = {}
sess.ref_map_services: Dict[str, Service] = {}
sess.ref_map_links: Dict[str, Link] = {}
game.ref_map_nodes: Dict[str, Node] = {}
game.ref_map_services: Dict[str, Service] = {}
game.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
@@ -400,7 +269,7 @@ class PrimaiteSession:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
sess.ref_map_services[service_ref] = new_service
game.ref_map_services[service_ref] = new_service
else:
print(f"service type not found {service_type}")
# service-dependent options
@@ -434,7 +303,7 @@ class PrimaiteSession:
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
sess.ref_map_applications[application_ref] = new_application
game.ref_map_applications[application_ref] = new_application
else:
print(f"application type not found {application_type}")
@@ -442,7 +311,7 @@ class PrimaiteSession:
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=opt.get("server_ip"),
server_ip_address=IPv4Address(opt.get("server_ip")),
payload=opt.get("payload"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
@@ -453,7 +322,7 @@ class PrimaiteSession:
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[
game.ref_map_nodes[
node_ref
] = (
new_node.uuid
@@ -461,8 +330,8 @@ class PrimaiteSession:
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
else:
@@ -472,13 +341,10 @@ class PrimaiteSession:
else:
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
# endpoint_a.enable()
# endpoint_b.enable()
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
# 3. create agents
game_cfg = cfg["game_config"]
agents_cfg = game_cfg["agents"]
agents_cfg = cfg["agents"]
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"] # noqa: F841
@@ -488,7 +354,7 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, game)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
@@ -497,7 +363,7 @@ class PrimaiteSession:
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
if "applications" in action_node_option:
@@ -505,7 +371,7 @@ class PrimaiteSession:
for application_option in action_node_option["applications"]:
# TODO: fix inconsistency with node uuids and application uuids. The node object get added to
# node_uuid, whereas here the application gets added by uuid.
application_uuid = sess.ref_map_applications[application_option["application_ref"]].uuid
application_uuid = game.ref_map_applications[application_option["application_ref"]].uuid
node_application_uuids.append(application_uuid)
action_space_cfg["options"]["application_uuids"].append(node_application_uuids)
@@ -522,12 +388,12 @@ class PrimaiteSession:
if "options" in action_config:
if "target_router_ref" in action_config["options"]:
_target = action_config["options"]["target_router_ref"]
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
action_space = ActionManager.from_config(game, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
@@ -541,7 +407,7 @@ class PrimaiteSession:
reward_function=rew_function,
agent_settings=agent_settings,
)
sess.agents.append(new_agent)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -549,8 +415,8 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agents.append(new_agent)
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -559,18 +425,10 @@ class PrimaiteSession:
reward_function=rew_function,
agent_settings=agent_settings,
)
sess.agents.append(new_agent)
game.agents.append(new_agent)
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
game._simulation_initial_state = deepcopy(game.simulation) # noqa
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
return sess
return game
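Note: after this change the game reads its episode limit from `options.max_episode_length`, and `reset()` restores the simulation from a deep copy taken at the end of `from_config`. The sketch below reproduces only that truncation and reset behaviour with toy stand-ins. `ToyOptions`, `ToySimulation` and `ToyGame` are hypothetical names, and the real classes carry far more state.

```python
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ToyOptions:
    """Stand-in for the game options; only the episode limit is modelled."""

    max_episode_length: int = 256


@dataclass
class ToySimulation:
    """Stand-in for the simulation; a dict of node states is enough here."""

    node_states: Dict[str, str] = field(default_factory=dict)


class ToyGame:
    def __init__(self, options: ToyOptions, simulation: ToySimulation) -> None:
        self.options = options
        self.simulation = simulation
        self.step_counter = 0
        self.episode_counter = 0
        # Snapshot taken once, before any agent has touched the simulation.
        self._simulation_initial_state = deepcopy(simulation)

    def advance_timestep(self) -> None:
        self.step_counter += 1

    def calculate_truncated(self) -> bool:
        return self.step_counter >= self.options.max_episode_length

    def reset(self) -> None:
        self.episode_counter += 1
        self.step_counter = 0
        # Copy the snapshot again so the next episode cannot mutate it.
        self.simulation = deepcopy(self._simulation_initial_state)


game = ToyGame(ToyOptions(max_episode_length=3), ToySimulation({"web_server": "up"}))
game.simulation.node_states["web_server"] = "compromised"
while not game.calculate_truncated():
    game.advance_timestep()
steps_taken = game.step_counter
game.reset()
```

Copying the snapshot on every reset, rather than assigning it directly, is what keeps the initial state pristine across episodes.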

View File

@@ -1,3 +0,0 @@
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]

View File

@@ -5,8 +5,8 @@ from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import load
from primaite.game.session import PrimaiteSession
from primaite.config.load import example_config_path, load
from primaite.session.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
@@ -42,6 +42,6 @@ if __name__ == "__main__":
args = parser.parse_args()
if not args.config:
_LOGGER.error("Please provide a config file using the --config " "argument")
args.config = example_config_path()
run(session_path=args.config)
run(args.config)
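Note: the hunk above replaces the "please provide a config" error with a fallback to the bundled example config. A minimal sketch of that fallback follows, assuming a single optional `--config` argument. `default_config_path` is a hypothetical stand-in for `example_config_path()`, and the file names are illustrative.

```python
import argparse
from pathlib import Path
from typing import List, Optional


def default_config_path() -> Path:
    """Hypothetical stand-in for the packaged example config location."""
    return Path("example_config.yaml")


def resolve_config(argv: Optional[List[str]] = None) -> Path:
    """Return the config path from the CLI, or the default when none is given."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    args = parser.parse_args(argv)
    if not args.config:
        return default_config_path()
    return Path(args.config)


fallback = resolve_config([])
explicit = resolve_config(["--config", "my_config.yaml"])
```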

View File

@@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"\n",
"\n",
"env_config = {\"game\":game}\n",
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n",
" .rollouts(num_rollout_workers=0)\n",
" .multi_agent(\n",
" policies={agent.agent_name for agent in game.rl_agents},\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .training(train_batch_size=128)\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"training_iteration\": 128},\n",
" checkpoint_config=air.CheckpointConfig(\n",
" checkpoint_frequency=10,\n",
" ),\n",
" ),\n",
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray.rllib.algorithms import ppo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"game\":game}\n",
"config = {\n",
" \"env\" : PrimaiteRayEnv,\n",
" \"env_config\" : env_config,\n",
" \"disable_env_checking\": True,\n",
" \"num_rollout_workers\": 0,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo = ppo.PPO(config=config)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(5):\n",
" result = algo.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo.save(\"temp/deleteme\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
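Note: the notebook above passes `{"game": game}` as `env_config` because RLlib constructs the environment itself and hands the constructor a single config mapping. The sketch below shows that adapter shape without Ray installed. `ToyEnv`, `ToyConfigEnv` and `make_env` are hypothetical stand-ins, and the inner "game" is a plain dict rather than a real game object.

```python
from typing import Any, Dict, Optional, Tuple


class ToyEnv:
    """Stand-in for an env whose constructor takes the game as a keyword argument."""

    def __init__(self, game: Dict[str, Any]) -> None:
        self.game = game
        self.steps = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[int, Dict]:
        self.steps = 0
        return 0, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, Dict]:
        self.steps += 1
        return self.steps, 0.0, False, False, {}


class ToyConfigEnv:
    """Adapter whose constructor accepts one config mapping and delegates to ToyEnv."""

    def __init__(self, env_config: Dict[str, Any]) -> None:
        self.env = ToyEnv(game=env_config["game"])

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        return self.env.reset(seed=seed)

    def step(self, action: int):
        return self.env.step(action)


def make_env(env_cls, env_config: Dict[str, Any]):
    """Mimics a framework that only knows the env class and a config mapping."""
    return env_cls(env_config)


env = make_env(ToyConfigEnv, {"game": {"name": "toy"}})
obs, info = env.reset()
next_obs, reward, terminated, truncated, step_info = env.step(0)
```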

View File

@@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.config.load import example_config_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game=game)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=1000)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save(\"deleteme\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
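Note: when the notebook above calls `model.learn`, each environment step runs the game in a fixed order: store the policy's action on the proxy agent, apply agent actions, advance the timestep, then update agents from the new state. The sketch below records that order with stub objects. `RecordingAgent`, `RecordingGame` and `env_step` are hypothetical, and the reward is a fixed placeholder rather than a real reward function.

```python
from typing import Any, Dict, List, Tuple


class RecordingAgent:
    """Stub proxy agent that just remembers the last action it was given."""

    def __init__(self) -> None:
        self.action: Any = None

    def store_action(self, action: Any) -> None:
        self.action = action


class RecordingGame:
    """Stub game that logs each call so the ordering can be inspected."""

    def __init__(self, agent: RecordingAgent, max_steps: int) -> None:
        self.agent = agent
        self.max_steps = max_steps
        self.step_counter = 0
        self.calls: List[Tuple[str, Any]] = []

    def apply_agent_actions(self) -> None:
        self.calls.append(("apply", self.agent.action))

    def advance_timestep(self) -> None:
        self.step_counter += 1
        self.calls.append(("advance", self.step_counter))

    def get_sim_state(self) -> Dict[str, int]:
        return {"step": self.step_counter}

    def update_agents(self, state: Dict[str, int]) -> None:
        self.calls.append(("update", state["step"]))

    def calculate_truncated(self) -> bool:
        return self.step_counter >= self.max_steps


def env_step(game: RecordingGame, action: Any):
    """One environment step, in the same order as the wrapper's step method."""
    game.agent.store_action(action)  # must happen before actions are applied
    game.apply_agent_actions()
    game.advance_timestep()
    state = game.get_sim_state()
    game.update_agents(state)
    reward = 0.0  # placeholder; the real wrapper reads the agent's current reward
    return state, reward, False, game.calculate_truncated(), {}


game = RecordingGame(RecordingAgent(), max_steps=1)
state, reward, terminated, truncated, info = env_step(game, action=7)
```

Storing the action first matters because the game pulls the action from the agent rather than receiving it as an argument.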

View File

@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n"
]
}
],
"source": [
"from primaite.session.session import PrimaiteSession\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.simulator.system.services.database.database_service import DatabaseService\n",
"\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"installing DNSServer on node domain_controller\n",
"installing DatabaseClient on node web_server\n",
"installing WebServer on node web_server\n",
"installing DatabaseService on node database_server\n",
"installing FTPClient on node database_server\n",
"installing FTPServer on node backup_server\n",
"installing DNSClient on node client_1\n",
"installing DNSClient on node client_2\n"
]
}
],
"source": [
"\n",
"with open(example_config_path(),'r') as cfgfile:\n",
" cfg = yaml.safe_load(cfgfile)\n",
"game = PrimaiteGame.from_config(cfg)\n",
"net = game.simulation.network\n",
"database_server = net.get_node_by_hostname('database_server')\n",
"web_server = net.get_node_by_hostname('web_server')\n",
"client_1 = net.get_node_by_hostname('client_1')\n",
"\n",
"db_service = database_server.software_manager.software[\"DatabaseService\"]\n",
"db_client = web_server.software_manager.software[\"DatabaseClient\"]\n",
"# db_client.run()\n",
"db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n",
"db_manipulation_bot.port_scan_p_of_success=1.0\n",
"db_manipulation_bot.data_manipulation_p_of_success=1.0\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"db_client.run()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_service.backup_database()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"db_manipulation_bot.run()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_service.restore_backup()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_manipulation_bot.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client_1.ping(database_server.ethernet_port[1].ip_address)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import validate_call, BaseModel"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class A(BaseModel):\n",
" x:int\n",
"\n",
" @validate_call\n",
" def increase_x(self, by:int) -> None:\n",
" self.x += 1"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"my_a = A(x=3)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"ename": "ValidationError",
"evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell://wsl%2Bubuntu/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb#X23sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n",
"File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float"
]
}
],
"source": [
"my_a.increase_x(3.2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

View File

@@ -0,0 +1,162 @@
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game: PrimaiteGame):
"""Initialise the environment."""
super().__init__()
self.game: "PrimaiteGame" = game
self.agent: ProxyAgent = self.game.rl_agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen by the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.game.reset()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)


class PrimaiteRayEnv(gymnasium.Env):
    """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""

    def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
        """Initialise the environment.

        :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
            which is the PrimaiteGame instance.
        :type env_config: Dict[str, PrimaiteGame]
        """
        self.env = PrimaiteGymEnv(game=env_config["game"])
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
        """Reset the environment."""
        return self.env.reset(seed=seed)

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
        """Perform a step in the environment."""
        return self.env.step(action)


class PrimaiteRayMARLEnv(MultiAgentEnv):
    """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""

    def __init__(self, env_config: Optional[Dict] = None) -> None:
        """Initialise the environment.

        :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
            which is the PrimaiteGame instance.
        :type env_config: Dict[str, PrimaiteGame]
        """
        self.game: PrimaiteGame = env_config["game"]
        """Reference to the primaite game"""
        self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
        """Dict of all possible agents in the environment, keyed by agent name. This dict should not change!"""
        self._agent_ids = list(self.agents.keys())

        self.terminateds = set()
        self.truncateds = set()
        self.observation_space = gymnasium.spaces.Dict(
            {name: agent.observation_manager.space for name, agent in self.agents.items()}
        )
        self.action_space = gymnasium.spaces.Dict(
            {name: agent.action_manager.space for name, agent in self.agents.items()}
        )
        super().__init__()

    def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
        """Reset the environment."""
        self.game.reset()
        state = self.game.get_sim_state()
        self.game.update_agents(state)
        next_obs = self._get_obs()
        info = {}
        return next_obs, info

    def step(
        self, actions: Dict[str, ActType]
    ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
        """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.

        :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
        :type actions: Dict[str, ActType]
        :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
            identifier.
        :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
        """
        # 1. Perform actions
        for agent_name, action in actions.items():
            self.agents[agent_name].store_action(action)
        self.game.apply_agent_actions()

        # 2. Advance timestep
        self.game.advance_timestep()

        # 3. Get next observations
        state = self.game.get_sim_state()
        self.game.update_agents(state)
        next_obs = self._get_obs()

        # 4. Get rewards
        rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
        terminateds = {name: False for name, _ in self.agents.items()}
        truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
        infos = {}
        terminateds["__all__"] = len(self.terminateds) == len(self.agents)
        truncateds["__all__"] = self.game.calculate_truncated()
        return next_obs, rewards, terminateds, truncateds, infos

    def _get_obs(self) -> Dict[str, ObsType]:
        """Return the current observation."""
        return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}
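
The single-agent `step()` above follows a fixed order: store the action on the agent, apply it, advance the timestep, refresh agent state, then read the observation, reward and truncation flag. The following is a minimal stdlib-only sketch of that ordering. `StubAgent`, `StubGame` and `env_step` are hypothetical stand-ins, not PrimAITE classes, and the reward and observation rules are invented purely for illustration.

```python
from typing import Any, Dict, List, Tuple


class StubAgent:
    """Hypothetical stand-in for ProxyAgent; not a PrimAITE class."""

    def __init__(self) -> None:
        self.stored_action: int = 0
        self.current_reward: float = 0.0
        self.current_observation: List[int] = [0]

    def store_action(self, action: int) -> None:
        self.stored_action = action


class StubGame:
    """Hypothetical stand-in for PrimaiteGame with a fixed episode length."""

    def __init__(self, max_episode_length: int = 3) -> None:
        self.rl_agents = [StubAgent()]
        self.max_episode_length = max_episode_length
        self.step_counter = 0
        self.applied_action = 0

    def apply_agent_actions(self) -> None:
        # Read back the action that the agent stored earlier in the step.
        self.applied_action = self.rl_agents[0].stored_action

    def advance_timestep(self) -> None:
        self.step_counter += 1

    def update_agents(self) -> None:
        # Invented rules: observe the step count, reward echoes the action.
        agent = self.rl_agents[0]
        agent.current_observation = [self.step_counter]
        agent.current_reward = float(self.applied_action)

    def calculate_truncated(self) -> bool:
        return self.step_counter >= self.max_episode_length


def env_step(game: StubGame, action: int) -> Tuple[List[int], float, bool, bool, Dict[str, Any]]:
    """Mirror the call order used by the wrapper's step()."""
    agent = game.rl_agents[0]
    agent.store_action(action)   # 1. the policy's choice is stored on the agent
    game.apply_agent_actions()   # 2. the game reads the stored action
    game.advance_timestep()      # 3. the simulation moves forward
    game.update_agents()         # 4. observation and reward are refreshed
    # terminated is always False here; the episode only ends by truncation
    return agent.current_observation, agent.current_reward, False, game.calculate_truncated(), {}


game = StubGame(max_episode_length=3)
history = [env_step(game, action) for action in (1, 0, 2)]
```

In this sketch only the third step reports `truncated=True`, because the stub's step counter reaches its episode length at that point.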

View File

@@ -0,0 +1,4 @@
from primaite.session.policy.rllib import RaySingleAgentPolicy
from primaite.session.policy.sb3 import SB3Policy
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
@@ -80,5 +80,3 @@ class PolicyABC(ABC):
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
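
The hunk above looks up a policy class in `cls._registry` by `config.rl_framework`, and the subclasses elsewhere in this commit are declared with an `identifier="..."` class keyword. One common way to wire those two together is `__init_subclass__`. The sketch below assumes that mechanism; the body of `PolicyABC` is not shown in this diff, so `PolicyBase`, `FakeSB3Policy` and `FakeRayPolicy` are hypothetical names and the real implementation may differ.

```python
from typing import Any, Dict, Type


class PolicyBase:
    """Hypothetical base class; the real PolicyABC may be implemented differently."""

    # Shared mapping from identifier string to the subclass registered under it.
    _registry: Dict[str, Type["PolicyBase"]] = {}

    def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if identifier in cls._registry:
            raise ValueError(f"Duplicate policy identifier: {identifier}")
        cls._registry[identifier] = cls

    @classmethod
    def from_config(cls, rl_framework: str) -> "PolicyBase":
        """Look up the subclass registered under rl_framework and instantiate it."""
        policy_type = cls._registry[rl_framework]
        return policy_type()


class FakeSB3Policy(PolicyBase, identifier="SB3"):
    """Hypothetical subclass used only for this example."""


class FakeRayPolicy(PolicyBase, identifier="RLLIB_single_agent"):
    """Hypothetical subclass used only for this example."""


policy = PolicyBase.from_config("RLLIB_single_agent")
```

With this pattern, adding a new framework only requires defining a subclass with a fresh identifier; the dispatch code in `from_config` does not need to change.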

View File

@@ -0,0 +1,106 @@
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING

from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.policy.policy import PolicyABC

if TYPE_CHECKING:
    from primaite.session.session import PrimaiteSession, TrainingOptions

import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig


class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
    """Single agent RL policy using Ray RLLib."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
        super().__init__(session=session)

        config = {
            "env": PrimaiteRayEnv,
            "env_config": {"game": session.game},
            "disable_env_checking": True,
            "num_rollout_workers": 0,
        }

        ray.shutdown()
        ray.init()

        self._algo = ppo.PPO(config=config)

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """Train the agent."""
        for ep in range(n_episodes):
            self._algo.train()

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """Evaluate the agent."""
        for ep in range(n_episodes):
            obs, info = self.session.env.reset()
            for step in range(self.session.game.options.max_episode_length):
                action = self._algo.compute_single_action(observation=obs, explore=False)
                obs, rew, term, trunc, info = self.session.env.step(action)

    def save(self, save_path: Path) -> None:
        """Save the policy to a file."""
        self._algo.save(save_path)

    def load(self, model_path: Path) -> None:
        """Load policy parameters from a file."""
        raise NotImplementedError

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
        """Create a policy from a config."""
        return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)


class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
    """Multi agent RL policy using Ray RLLib."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
        """Initialise multi agent policy wrapper."""
        super().__init__(session=session)

        self.config = (
            PPOConfig()
            .environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
            .rollouts(num_rollout_workers=0)
            .multi_agent(
                policies={agent.agent_name for agent in session.game.rl_agents},
                policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
            )
            .training(train_batch_size=128)
        )

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """Train the agent."""
        checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
        tune.Tuner(
            "PPO",
            run_config=air.RunConfig(
                stop={"training_iteration": n_episodes * timesteps_per_episode},
                checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
            ),
            param_space=self.config,
        ).fit()

    def load(self, model_path: Path) -> None:
        """Load policy parameters from a file."""
        return NotImplemented

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """Evaluate trained policy."""
        return NotImplemented

    def save(self, save_path: Path) -> None:
        """Save policy parameters to a file."""
        return NotImplemented

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
        """Create policy from config."""
        return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
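
The multi-agent configuration above registers one policy per agent name and maps each agent to the policy with the same name, while the MARL environment returns per-agent dictionaries that also carry an `"__all__"` entry. The sketch below illustrates both ideas with plain Python and no Ray dependency. The agent names and the `build_truncateds` helper are hypothetical and exist only for this example.

```python
from typing import Any, Dict, List, Set

# Hypothetical agent names; in the diff they come from game.rl_agents.
agent_names: List[str] = ["defender_a", "defender_b"]

# One policy identifier per agent, as in the `policies={...}` set comprehension.
policies: Set[str] = set(agent_names)


def policy_mapping_fn(agent_id: str, episode: Any = None, worker: Any = None, **kw: Any) -> str:
    """Identity mapping: each agent trains the policy that shares its name."""
    return agent_id


def build_truncateds(names: List[str], episode_truncated: bool) -> Dict[str, bool]:
    """Build a per-agent truncation dict plus the episode-wide "__all__" flag."""
    truncateds = {name: episode_truncated for name in names}
    truncateds["__all__"] = episode_truncated
    return truncateds


assignments = {name: policy_mapping_fn(name) for name in agent_names}
truncateds = build_truncateds(agent_names, episode_truncated=True)
```

Because the mapping is the identity, no two agents share weights in this setup; sharing one policy between agents would instead mean returning the same identifier for several agent ids.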

View File

@@ -8,10 +8,10 @@ from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):

View File

@@ -0,0 +1,113 @@
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC


class TrainingOptions(BaseModel):
    """Options for training the RL agent."""

    model_config = ConfigDict(extra="forbid")

    rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
    rl_algorithm: Literal["PPO", "A2C"]
    n_learn_episodes: int
    n_eval_episodes: Optional[int] = None
    max_steps_per_episode: int
    # checkpoint_freq: Optional[int] = None
    deterministic_eval: bool
    seed: Optional[int]
    n_agents: int
    agent_references: List[str]


class SessionMode(Enum):
    """Helper to keep track of the current session mode."""

    TRAIN = "train"
    EVAL = "eval"
    MANUAL = "manual"


class PrimaiteSession:
    """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""

    def __init__(self, game: PrimaiteGame):
        """Initialise PrimaiteSession object."""
        self.training_options: TrainingOptions
        """Options specific to agent training."""

        self.mode: SessionMode = SessionMode.MANUAL
        """Current session mode."""

        self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
        """The environment that the RL algorithm can consume."""

        self.policy: PolicyABC
        """The reinforcement learning policy."""

        self.io_manager = SessionIO()
        """IO manager for the session."""

        self.game: PrimaiteGame = game
        """Primaite Game object for managing main simulation loop and agents."""

    def start_session(self) -> None:
        """Commence the training/eval session."""
        self.mode = SessionMode.TRAIN
        n_learn_episodes = self.training_options.n_learn_episodes
        n_eval_episodes = self.training_options.n_eval_episodes
        max_steps_per_episode = self.training_options.max_steps_per_episode

        deterministic_eval = self.training_options.deterministic_eval
        self.policy.learn(
            n_episodes=n_learn_episodes,
            timesteps_per_episode=max_steps_per_episode,
        )
        self.save_models()

        self.mode = SessionMode.EVAL
        # n_eval_episodes is Optional and defaults to None, so guard before comparing
        if n_eval_episodes is not None and n_eval_episodes > 0:
            self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)

        self.mode = SessionMode.MANUAL

    def save_models(self) -> None:
        """Save the RL models."""
        save_path = self.io_manager.generate_model_save_path("temp_model_name")
        self.policy.save(save_path)

    @classmethod
    def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
        """Create a PrimaiteSession object from a config dictionary."""
        game = PrimaiteGame.from_config(cfg)

        sess = cls(game=game)

        sess.training_options = TrainingOptions(**cfg["training_config"])

        # READ IO SETTINGS (this sets the global session path as well)  # TODO: GLOBAL SIDE EFFECTS...
        io_settings = cfg.get("io_settings", {})
        sess.io_manager.settings = SessionIOSettings(**io_settings)

        # CREATE ENVIRONMENT
        if sess.training_options.rl_framework == "RLLIB_single_agent":
            sess.env = PrimaiteRayEnv(env_config={"game": game})
        elif sess.training_options.rl_framework == "RLLIB_multi_agent":
            sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
        elif sess.training_options.rl_framework == "SB3":
            sess.env = PrimaiteGymEnv(game=game)

        sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
        if agent_load_path:
            sess.policy.load(Path(agent_load_path))

        return sess
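
`start_session` moves the session through TRAIN, then EVAL, then back to MANUAL, saving the model between training and evaluation. The sketch below traces that lifecycle with a recording stub in place of a real policy. `RecordingPolicy` and `run_session` are hypothetical names for this example, and because `n_eval_episodes` is optional in `TrainingOptions`, the sketch assumes that both `None` and `0` mean "skip evaluation".

```python
from enum import Enum
from typing import List, Optional, Tuple


class SessionMode(Enum):
    TRAIN = "train"
    EVAL = "eval"
    MANUAL = "manual"


class RecordingPolicy:
    """Hypothetical policy stub that records which calls it receives."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, int]] = []

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        self.calls.append(("learn", n_episodes))

    def save(self) -> None:
        self.calls.append(("save", 0))

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        self.calls.append(("eval", n_episodes))


def run_session(
    policy: RecordingPolicy,
    n_learn_episodes: int,
    max_steps_per_episode: int,
    n_eval_episodes: Optional[int],
    deterministic_eval: bool,
) -> List[SessionMode]:
    """Return the sequence of modes the session passes through."""
    modes = [SessionMode.TRAIN]
    policy.learn(n_episodes=n_learn_episodes, timesteps_per_episode=max_steps_per_episode)
    policy.save()

    modes.append(SessionMode.EVAL)
    # Assumption: None and 0 both skip evaluation.
    if n_eval_episodes:
        policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)

    modes.append(SessionMode.MANUAL)
    return modes


train_only = RecordingPolicy()
train_only_modes = run_session(train_only, 10, 128, None, True)

train_and_eval = RecordingPolicy()
run_session(train_and_eval, 10, 128, 5, True)
```

Here `train_only` never receives an `eval` call, while `train_and_eval` records `learn`, `save` and `eval` in that order.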

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,8 @@ import pytest
import yaml
from primaite import getLogger
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession

View File

@@ -0,0 +1,45 @@
import pytest
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayMARLEnv


@pytest.mark.skip(reason="Slow, reenable later")
def test_rllib_multi_agent_compatibility():
    """Test that the PrimaiteRayMARLEnv class can be used with a multi agent RLLIB system."""
    with open(example_config_path(), "r") as f:
        cfg = yaml.safe_load(f)

    game = PrimaiteGame.from_config(cfg)

    ray.shutdown()
    ray.init()

    env_config = {"game": game}
    config = (
        PPOConfig()
        .environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
        .rollouts(num_rollout_workers=0)
        .multi_agent(
            policies={agent.agent_name for agent in game.rl_agents},
            policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
        )
        .training(train_batch_size=128)
    )

    tune.Tuner(
        "PPO",
        run_config=air.RunConfig(
            stop={"training_iteration": 128},
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=10,
            ),
        ),
        param_space=config,
    ).fit()

View File

@@ -0,0 +1,40 @@
import tempfile
from pathlib import Path
import pytest
import ray
import yaml
from ray.rllib.algorithms import ppo
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayEnv


@pytest.mark.skip(reason="Slow, reenable later")
def test_rllib_single_agent_compatibility():
    """Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
    with open(example_config_path(), "r") as f:
        cfg = yaml.safe_load(f)

    game = PrimaiteGame.from_config(cfg)

    ray.shutdown()
    ray.init()

    env_config = {"game": game}
    config = {
        "env": PrimaiteRayEnv,
        "env_config": env_config,
        "disable_env_checking": True,
        "num_rollout_workers": 0,
    }

    algo = ppo.PPO(config=config)

    for i in range(5):
        result = algo.train()

    save_file = Path(tempfile.gettempdir()) / "ray/"
    algo.save(save_file)
    assert save_file.exists()

View File

@@ -0,0 +1,27 @@
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
import tempfile
from pathlib import Path
import yaml
from stable_baselines3 import PPO
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv


def test_sb3_compatibility():
    """Test that the Gymnasium environment can be used with an SB3 agent."""
    with open(example_config_path(), "r") as f:
        cfg = yaml.safe_load(f)

    game = PrimaiteGame.from_config(cfg)
    gym = PrimaiteGymEnv(game=game)
    model = PPO("MlpPolicy", gym)

    model.learn(total_timesteps=1000)

    save_path = Path(tempfile.gettempdir()) / "model.zip"
    model.save(save_path)

    assert (save_path).exists()

View File

@@ -8,6 +8,7 @@ CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
class TestPrimaiteSession:
@@ -19,15 +20,15 @@ class TestPrimaiteSession:
raise AssertionError
assert session is not None
assert session.simulation
assert len(session.agents) == 3
assert len(session.rl_agents) == 1
assert session.game.simulation
assert len(session.game.agents) == 3
assert len(session.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.simulation.network
assert len(session.simulation.network.nodes) == 10
assert session.game.simulation.network
assert len(session.game.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):
@@ -64,6 +65,13 @@ class TestPrimaiteSession:
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
@pytest.mark.skip(reason="Slow, reenable later")
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
def test_multi_agent_session(self, temp_primaite_session):
"""Check that we can run a training session with a multi agent system."""
with temp_primaite_session as session:
session.start_session()
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@@ -72,12 +80,12 @@ class TestPrimaiteSession:
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.simulation.network.get_node_by_hostname("client_1")
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.reset()
client_1 = session.simulation.network.get_node_by_hostname("client_1")
session.game.reset()
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software