Merge branch 'dev' into feature/2068-Validate_documentation

This commit is contained in:
Nick Todd
2023-11-24 15:27:33 +00:00
41 changed files with 5925 additions and 3085 deletions

View File

@@ -7,207 +7,28 @@
Run a PrimAITE Session
======================
``PrimaiteSession`` allows the user to train or evaluate an RL agent on the primaite simulation with just a config file,
no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment,
policy, simulator, agents, and IO.
If you want finer control over the RL policy, you can interface with the :py:module::`primaite.session.environment`
module directly without running a session.
Run
---
A PrimAITE session can be ran either with the ``primaite session`` command from the cli
A PrimAITE session can be started either with the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
.. note::
🚧 *UNDER CONSTRUCTION* 🚧
There are two parameters that can be specified:
- ``--config``: The path to the config file to use. If not specified, the default config file is used.
- ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created.
..
.. code-block:: bash
:caption: Unix CLI
Outputs
-------
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to training config yaml file>
lay_down_config = <path to lay down config yaml file>
run(training_config, lay_down_config)
When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0/sessions``).
The sub-directory is formatted as such: ``~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-dd>/``
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``.
``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``.
To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``legacy-ldc`` options.
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to legacy training config yaml file>
lay_down_config = <path to legacy lay down config yaml file>
run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True)
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents (training checkpoints and a final trained agent)
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date & time the session started in iso format.
* **end_datetime** - The date & time the session ended in iso format.
* **learning**
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **evaluation**
* **total_episodes** - The total number of evaluation episodes completed.
* **total_time_steps** - The total number of evaluation time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
**Results**
PrimAITE automatically creates two sets of results from each learning and evaluation session:
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
* All transactions - a csv file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
**Diagrams**
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
**Example Session Directory Structure**
.. code-block:: text
~/
└── primaite/
└── 2.0.0/
└── sessions/
└── 2023-07-18/
└── 2023-07-18_11-06-04/
├── evaluation/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
├── learning/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
│ ├── checkpoints/
│ │ └── sb3ppo_10.zip
│ ├── SB3_PPO.zip
│ └── tensorboard_logs/
│ ├── PPO_1/
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
│ ├── PPO_2/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
│ ├── PPO_3/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
│ ├── PPO_4/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
│ ├── PPO_5/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
│ ├── PPO_6/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
│ ├── PPO_7/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
│ ├── PPO_8/
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
│ ├── PPO_9/
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
│ └── PPO_10/
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
├── network_2023-07-18_11-06-04.png
└── session_metadata.json
Loading a session
-----------------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
Running a session creates a session output directory in your user data folder. The filepath looks like this:
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
the saved agent checkpoints, and final model.

View File

@@ -43,7 +43,7 @@ Example
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
data_manipulation_bot.run()
This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table.

View File

@@ -14,9 +14,11 @@ The ``DatabaseService`` provides a SQL database server simulation by extending t
Key capabilities
^^^^^^^^^^^^^^^^
- Initialises a SQLite database file in the ``Node`` 's ``FileSystem`` upon creation.
- Creates a database file in the ``Node`` 's ``FileSystem`` upon creation.
- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs.
- Authenticates connections using a configurable password.
- Simulates ``SELECT`` and ``DELETE`` SQL queries.
- Returns query results and status codes back to clients.
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
@@ -28,9 +30,9 @@ Usage
Implementation
^^^^^^^^^^^^^^
- Uses SQLite for persistent storage.
- Creates the database file within the node's file system.
- Manages client connections in a dictionary by session ID.
- Processes SQL queries.
- Returns results and status codes in a standard dictionary format.
- Extends Service class for integration with ``SoftwareManager``.

View File

@@ -39,6 +39,7 @@ dependencies = [
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.1.1",
"ray[rllib] == 2.8.0, < 3"
]
[tool.setuptools.dynamic]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractAction(ABC):
@@ -43,7 +43,7 @@ class AbstractAction(ABC):
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
align with the keyword args of the form_request method."""
self.manager: ActionManager = manager
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
"""Reference to the ActionManager which created this action. This is used to access the game and simulation
objects."""
@abstractmethod
@@ -559,7 +559,7 @@ class ActionManager:
def __init__(
self,
session: "PrimaiteSession", # reference to session for looking up stuff
game: "PrimaiteGame", # reference to game for information lookup
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
max_folders_per_node: int = 2, # allows calculating shape
@@ -574,8 +574,8 @@ class ActionManager:
) -> None:
"""Init method for ActionManager.
:param session: Reference to the session to which the agent belongs.
:type session: PrimaiteSession
:param game: Reference to the game to which the agent belongs.
:type game: PrimaiteGame
:param actions: List of action types which should be made available to the agent.
:type actions: List[str]
:param node_uuids: List of node UUIDs that this agent can act on.
@@ -599,8 +599,8 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.session: "PrimaiteSession" = session
self.sim: Simulation = self.session.simulation
self.game: "PrimaiteGame" = game
self.sim: Simulation = self.game.simulation
self.node_uuids: List[str] = node_uuids
self.protocols: List[str] = protocols
self.ports: List[str] = ports
@@ -826,7 +826,7 @@ class ActionManager:
return nics[nic_idx]
@classmethod
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config definition.
@@ -845,20 +845,20 @@ class ActionManager:
These options are used to calculate the shape of the action space, and to provide additional information
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
:param session: The Primaite Session to which the agent belongs.
:type session: PrimaiteSession
:param game: The Primaite Game to which the agent belongs.
:type game: PrimaiteGame
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
obj = cls(
session=session,
game=game,
actions=cfg["action_list"],
# node_uuids=cfg["options"]["node_uuids"],
**cfg["options"],
protocols=session.options.protocols,
ports=session.options.ports,
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=None,
act_map=cfg.get("action_map"),
)

View File

@@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractObservation(ABC):
@@ -37,10 +37,10 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
"""
pass
@@ -91,13 +91,13 @@ class FileObservation(AbstractObservation):
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param session: _description_
:type session: PrimaiteSession
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
@@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
class LinkObservation(AbstractObservation):
@@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation):
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
@@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
@@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation):
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
@@ -376,15 +376,13 @@ class NicObservation(AbstractObservation):
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
) -> "NicObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
:type parent_where: Optional[List[str]]
@@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation):
def from_config(
cls,
config: Dict,
session: "PrimaiteSession",
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
@@ -526,8 +524,8 @@ class NodeObservation(AbstractObservation):
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
@@ -543,24 +541,24 @@ class NodeObservation(AbstractObservation):
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = session.ref_map_nodes[config["node_ref"]]
node_uuid = game.ref_map_nodes[config["node_ref"]]
if parent_where is None:
where = ["network", "nodes", node_uuid]
else:
where = parent_where + ["nodes", node_uuid]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
@@ -694,13 +692,13 @@ class AclObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
@@ -709,15 +707,15 @@ class AclObservation(AbstractObservation):
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_ref"]
nic_num = ip_map_config["nic_num"]
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
router_uuid = game.ref_map_nodes[config["router_node_ref"]]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -740,7 +738,7 @@ class NullObservation(AbstractObservation):
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
@@ -836,14 +834,14 @@ class UC2BlueObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
@@ -855,7 +853,7 @@ class UC2BlueObservation(AbstractObservation):
nodes = [
NodeObservation.from_config(
config=n,
session=session,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
@@ -865,13 +863,13 @@ class UC2BlueObservation(AbstractObservation):
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, session=session)
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, session=session)
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
@@ -907,17 +905,17 @@ class UC2RedObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
@@ -966,7 +964,7 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
@@ -974,14 +972,14 @@ class ObservationManager:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")

View File

@@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractReward:
@@ -47,13 +47,13 @@ class AbstractReward:
@classmethod
@abstractmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: AbstractReward
"""
@@ -68,13 +68,13 @@ class DummyReward(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
"""
return cls()
@@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward):
return 0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
@@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward):
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = session.ref_map_nodes[node_ref]
node_uuid = game.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
@@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward component.
:rtype: WebServer404Penalty
"""
@@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward):
)
_LOGGER.warn(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = session.ref_map_nodes[node_ref]
service_uuid = session.ref_map_services[service_ref].uuid
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref].uuid
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
@@ -265,13 +265,13 @@ class RewardFunction:
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction":
"""Create a reward function from a config dictionary.
:param config: dict of options for the reward manager's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:param game: Reference to the PrimAITE Game object
:type game: PrimaiteGame
:return: The reward manager.
:rtype: RewardFunction
"""
@@ -281,6 +281,6 @@ class RewardFunction:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.__rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game)
new.regsiter_component(component=rew_instance, weight=weight)
return new

View File

@@ -1,11 +1,8 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from enum import Enum
"""PrimAITE game - Encapsulates the simulation and agents."""
from copy import deepcopy
from ipaddress import IPv4Address
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
from typing import Dict, List
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
@@ -13,8 +10,6 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -36,65 +31,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.session.apply_agent_actions()
self.session.advance_timestep()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.session.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
class PrimaiteGameOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
@@ -103,43 +40,26 @@ class PrimaiteSessionOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
max_episode_length: int = 256
ports: List[str]
protocols: List[str]
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
class PrimaiteGame:
"""
Primaite game encapsulates the simulation and agents which interact with it.
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment.
"""
def __init__(self):
"""Initialise a PrimaiteSession object."""
"""Initialise a PrimaiteGame object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -152,15 +72,9 @@ class PrimaiteSession:
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
@@ -173,40 +87,6 @@ class PrimaiteSession:
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -225,7 +105,7 @@ class PrimaiteSession:
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
@@ -267,29 +147,29 @@ class PrimaiteSession:
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
max_steps = self.options.max_episode_length
if current_step >= max_steps:
return True
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
"""Reset the game, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
"""Close the game, this will close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
2. game_config: options for the game itself. Used by PrimaiteGame.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
@@ -297,26 +177,19 @@ class PrimaiteSession:
:param cfg: The config dictionary.
:type cfg: dict
:return: A PrimaiteSession object.
:rtype: PrimaiteSession
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
sess = cls()
sess.options = PrimaiteSessionOptions(
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
# 1. create simulation
sim = game.simulation
net = sim.network
sess.ref_map_nodes: Dict[str, Node] = {}
sess.ref_map_services: Dict[str, Service] = {}
sess.ref_map_links: Dict[str, Link] = {}
game.ref_map_nodes: Dict[str, Node] = {}
game.ref_map_services: Dict[str, Service] = {}
game.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
@@ -380,7 +253,7 @@ class PrimaiteSession:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
sess.ref_map_services[service_ref] = new_service
game.ref_map_services[service_ref] = new_service
else:
print(f"service type not found {service_type}")
# service-dependent options
@@ -405,7 +278,7 @@ class PrimaiteSession:
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
sess.ref_map_applications[application_ref] = new_application
game.ref_map_applications[application_ref] = new_application
else:
print(f"application type not found {application_type}")
if "nics" in node_cfg:
@@ -414,16 +287,16 @@ class PrimaiteSession:
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[
game.ref_map_nodes[
node_ref
] = (
new_node.uuid
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
else:
@@ -433,11 +306,10 @@ class PrimaiteSession:
else:
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
# 3. create agents
game_cfg = cfg["game_config"]
agents_cfg = game_cfg["agents"]
agents_cfg = cfg["agents"]
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"] # noqa: F841
@@ -447,14 +319,14 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, game)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
@@ -466,12 +338,12 @@ class PrimaiteSession:
if "options" in action_config:
if "target_router_ref" in action_config["options"]:
_target = action_config["options"]["target_router_ref"]
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
action_space = ActionManager.from_config(game, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
@@ -482,7 +354,7 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -490,8 +362,8 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agents.append(new_agent)
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
@@ -499,16 +371,10 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
game.agents.append(new_agent)
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
game._simulation_initial_state = deepcopy(game.simulation) # noqa
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess
return game

View File

@@ -1,3 +0,0 @@
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]

View File

@@ -5,8 +5,8 @@ from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import load
from primaite.game.session import PrimaiteSession
from primaite.config.load import example_config_path, load
from primaite.session.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
@@ -42,6 +42,6 @@ if __name__ == "__main__":
args = parser.parse_args()
if not args.config:
_LOGGER.error("Please provide a config file using the --config " "argument")
args.config = example_config_path()
run(session_path=args.config)
run(args.config)

View File

@@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"\n",
"\n",
"env_config = {\"game\":game}\n",
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n",
" .rollouts(num_rollout_workers=0)\n",
" .multi_agent(\n",
" policies={agent.agent_name for agent in game.rl_agents},\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .training(train_batch_size=128)\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"training_iteration\": 128},\n",
" checkpoint_config=air.CheckpointConfig(\n",
" checkpoint_frequency=10,\n",
" ),\n",
" ),\n",
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,122 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray.rllib.algorithms import ppo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"game\":game}\n",
"config = {\n",
" \"env\" : PrimaiteRayEnv,\n",
" \"env_config\" : env_config,\n",
" \"disable_env_checking\": True,\n",
" \"num_rollout_workers\": 0,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo = ppo.PPO(config=config)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(5):\n",
" result = algo.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo.save(\"temp/deleteme\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.config.load import example_config_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game=game)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=1000)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save(\"deleteme\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

View File

@@ -0,0 +1,162 @@
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game: PrimaiteGame):
"""Initialise the environment."""
super().__init__()
self.game: "PrimaiteGame" = game
self.agent: ProxyAgent = self.game.rl_agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.game.reset()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
"""
self.env = PrimaiteGymEnv(game=env_config["game"])
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Optional[Dict] = None) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
"""
self.game: PrimaiteGame = env_config["game"]
"""Reference to the primaite game"""
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
"""List of all possible agents in the environment. This list should not change!"""
self._agent_ids = list(self.agents.keys())
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{name: agent.observation_manager.space for name, agent in self.agents.items()}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
super().__init__()
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
self.game.reset()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
def step(
self, actions: Dict[str, ActType]
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
:type actions: Dict[str, ActType]
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
# 3. Get next observations
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
return next_obs, rewards, terminateds, truncateds, infos
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}

View File

@@ -0,0 +1,4 @@
from primaite.session.policy.rllib import RaySingleAgentPolicy
from primaite.session.policy.sb3 import SB3Policy
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
@@ -80,5 +80,3 @@ class PolicyABC(ABC):
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass

View File

@@ -0,0 +1,106 @@
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
config = {
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
"num_rollout_workers": 0,
}
ray.shutdown()
ray.init()
self._algo = ppo.PPO(config=config)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
for ep in range(n_episodes):
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
for ep in range(n_episodes):
obs, info = self.session.env.reset()
for step in range(self.session.game.options.max_episode_length):
action = self._algo.compute_single_action(observation=obs, explore=False)
obs, rew, term, trunc, info = self.session.env.step(action)
def save(self, save_path: Path) -> None:
"""Save the policy to a file."""
self._algo.save(save_path)
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
"""Mutli agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
"""Initialise multi agent policy wrapper."""
super().__init__(session=session)
self.config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in session.game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": n_episodes * timesteps_per_episode},
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
),
param_space=self.config,
).fit()
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
return NotImplemented
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate trained policy."""
return NotImplemented
def save(self, save_path: Path) -> None:
"""Save policy parameters to a file."""
return NotImplemented
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
"""Create policy from config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -8,10 +8,10 @@ from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):

View File

@@ -0,0 +1,113 @@
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self, game: PrimaiteGame):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
"""The environment that the RL algorithm can consume."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager = SessionIO()
"""IO manager for the session."""
self.game: PrimaiteGame = game
"""Primaite Game object for managing main simulation loop and agents."""
def start_session(self) -> None:
"""Commence the training/eval session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
game = PrimaiteGame.from_config(cfg)
sess = cls(game=game)
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config={"game": game})
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game=game)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess

View File

@@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network:
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
# Client 2
client_2 = Computer(

View File

@@ -2,13 +2,14 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from uuid import uuid4
from prettytable import PrettyTable
from primaite import getLogger
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.software_manager import SoftwareManager
_LOGGER = getLogger(__name__)
class DatabaseClient(Application):
"""
@@ -148,21 +149,6 @@ class DatabaseClient(Application):
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
def _print_data(self, data: Dict):
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
if data:
table = PrettyTable(list(data.values())[0])
table.align = "l"
table.title = f"{self.sys_log.hostname} Database Client"
for row in data.values():
table.add_row(row.values())
print(table)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
@@ -179,5 +165,5 @@ class DatabaseClient(Application):
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
self._print_data(payload["data"])
_LOGGER.debug(f"Received payload {payload}")
return True

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")

View File

@@ -1,10 +1,6 @@
import sqlite3
from datetime import datetime
from ipaddress import IPv4Address
from sqlite3 import OperationalError
from typing import Any, Dict, List, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from typing import Any, Dict, List, Literal, Optional, Union
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -19,7 +15,7 @@ class DatabaseService(Service):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
password: Optional[str] = None
@@ -41,38 +37,6 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
self._connect()
def _connect(self):
self._conn = sqlite3.connect(self._db_file.sim_path)
self._cursor = self._conn.cursor()
def tables(self) -> List[str]:
"""
Get a list of table names present in the database.
:return: List of table names.
"""
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
results = self._process_sql(sql, None)
if isinstance(results["data"], dict):
return list(results["data"].keys())
return []
def show(self, markdown: bool = False):
"""
Prints a list of table names in the database using PrettyTable.
:param markdown: Whether to output the table in Markdown format.
"""
table = PrettyTable(["Table"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.file_system.sys_log.hostname} Database"
for row in self.tables():
table.add_row([row])
print(table)
def configure_backup(self, backup_server: IPv4Address):
"""
@@ -89,8 +53,6 @@ class DatabaseService(Service):
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
self._conn.close()
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
@@ -98,12 +60,10 @@ class DatabaseService(Service):
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server,
src_file_name=self._db_file.name,
src_folder_name=self._db_file.folder.name,
src_folder_name=self.folder.name,
dest_folder_name=str(self.uuid),
dest_file_name="database.db",
real_file_path=self._db_file.sim_path,
)
self._connect()
if response:
return True
@@ -125,25 +85,29 @@ class DatabaseService(Service):
dest_ip_address=self.backup_server,
)
if response:
self._conn.close()
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self._connect()
if not response:
self.sys_log.error("Unable to restore database backup.")
return False
return self._db_file is not None
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self.sys_log.error("Unable to restore database backup.")
return False
if self._db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
self.set_health_state(SoftwareHealthState.GOOD)
return True
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
self.folder = self._db_file.folder
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
def _process_connect(
self, session_id: str, password: Optional[str] = None
@@ -163,31 +127,32 @@ class DatabaseService(Service):
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
Possible queries:
- SELECT : returns the data
- DELETE : deletes the data
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
try:
self._cursor.execute(query)
self._conn.commit()
except OperationalError:
# Handle the case where the table does not exist.
self.sys_log.error(f"{self.name}: Error, query failed")
return {"status_code": 404, "data": {}}
data = []
description = self._cursor.description
if description:
headers = []
for header in description:
headers.append(header[0])
data = self._cursor.fetchall()
if data and headers:
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
if query == "SELECT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
if self.health_state_actual == SoftwareHealthState.GOOD:
self.health_state_actual = SoftwareHealthState.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
else:
# Invalid query
return {"status_code": 500, "data": False}
def describe_state(self) -> Dict:
"""

View File

@@ -106,7 +106,7 @@ class WebServer(Service):
# get data from DatabaseServer
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
# get all users
if db_client.query("SELECT * FROM user;"):
if db_client.query("SELECT"):
# query succeeded
response.status_code = HttpStatusCode.OK

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,13 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
import nodeenv
import pytest
import yaml
from primaite import getLogger
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession

View File

@@ -0,0 +1,45 @@
import pytest
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayMARLEnv
@pytest.mark.skip(reason="Slow, reenable later")
def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": 128},
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=10,
),
),
param_space=config,
).fit()

View File

@@ -0,0 +1,40 @@
import tempfile
from pathlib import Path
import pytest
import ray
import yaml
from ray.rllib.algorithms import ppo
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayEnv
@pytest.mark.skip(reason="Slow, reenable later")
def test_rllib_single_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = {
"env": PrimaiteRayEnv,
"env_config": env_config,
"disable_env_checking": True,
"num_rollout_workers": 0,
}
algo = ppo.PPO(config=config)
for i in range(5):
result = algo.train()
save_file = Path(tempfile.gettempdir()) / "ray/"
algo.save(save_file)
assert save_file.exists()

View File

@@ -0,0 +1,27 @@
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
import tempfile
from pathlib import Path
import yaml
from stable_baselines3 import PPO
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
gym = PrimaiteGymEnv(game=game)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)
save_path = Path(tempfile.gettempdir()) / "model.zip"
model.save(save_path)
assert (save_path).exists()

View File

@@ -1,12 +1,14 @@
import pydantic
import pytest
from tests import TEST_ASSETS_ROOT
from tests.conftest import TempPrimaiteSession
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
class TestPrimaiteSession:
@@ -18,15 +20,15 @@ class TestPrimaiteSession:
raise AssertionError
assert session is not None
assert session.simulation
assert len(session.agents) == 3
assert len(session.rl_agents) == 1
assert session.game.simulation
assert len(session.game.agents) == 3
assert len(session.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.simulation.network
assert len(session.simulation.network.nodes) == 10
assert session.game.simulation.network
assert len(session.game.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):
@@ -63,6 +65,27 @@ class TestPrimaiteSession:
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
@pytest.mark.skip(reason="Slow, reenable later")
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
def test_multi_agent_session(self, temp_primaite_session):
"""Check that we can run a training session with a multi agent system."""
with temp_primaite_session as session:
session.start_session()
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.game.reset()
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software

View File

@@ -19,16 +19,16 @@ def test_data_manipulation(uc2_network):
db_service.backup_database()
# First check that the DB client on the web_server can successfully query the users table on the database
assert db_client.query("SELECT * FROM user;")
assert db_client.query("SELECT")
# Now we run the DataManipulationBot
db_manipulation_bot.run()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT * FROM user;")
assert not db_client.query("SELECT")
# Now restore the database
db_service.restore_backup()
# Now check that the DB client on the web_server can successfully query the users table on the database
assert db_client.query("SELECT * FROM user;")
assert db_client.query("SELECT")

View File

@@ -57,7 +57,7 @@ def test_database_client_query(uc2_network):
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query("SELECT * FROM user;")
assert db_client.query("SELECT")
def test_create_database_backup(uc2_network):

View File

@@ -17,4 +17,4 @@ def test_creation():
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;"
assert data_manipulation_bot.payload == "DELETE"