Merged PR 195: Initial game layer functionality
## Summary * add a new module `game` to PrimAITE * this includes a new PrimaiteSession which creates a simulation, and multiple agents, and talks to GATE * agent interface * agent actions to work with Simulator requests * agent observations that work with Simulator State * agent rewards also with Simulator state **note** that this branch is currently still in a broken state. Still need to do things like updating readme, install instructions, refactoring some poorly designed classes, and removing legacy code. These will be done in subsequent PRs to avoid making this PR even bigger than it needs to be. Still, please review this to familiarise yourself. ## Test process Some unit tests exist but their coverage will be expanded. I performed some test runs with to train a SB3 agent in GATE with a primaite simulation. ## Checklist - [y] PR is linked to a **work item** - [y] **acceptance criteria** of linked ticket are met - [~] performed **self-review** of the code - [n] written **tests** for any new functionality added with this PR - [~] updated the **documentation** if this PR changes or adds functionality - [~] written/updated **design docs** if this PR implements new functionality - [~] updated the **change log** - [y] ran **pre-commit** checks for code style - [n] attended to any **TO-DOs** left in the code Related work items: #1622, #1759, #1760, #1761, #1764, #1765, #1766, #1767, #1768, #1879, #1924
This commit is contained in:
@@ -98,7 +98,9 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
source/getting_started
|
||||
source/about
|
||||
source/config
|
||||
source/config(v2)
|
||||
source/simulation
|
||||
source/game_layer
|
||||
source/primaite_session
|
||||
source/custom_agent
|
||||
PrimAITE API <source/_autosummary/primaite>
|
||||
|
||||
491
docs/source/config(v2).rst
Normal file
491
docs/source/config(v2).rst
Normal file
@@ -0,0 +1,491 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _config:
|
||||
|
||||
The Config Files Explained
|
||||
==========================
|
||||
|
||||
Note: This file describes the config files used in legacy PrimAITE v2.0. This file will be removed soon.
|
||||
|
||||
PrimAITE uses two configuration files for its operation:
|
||||
|
||||
* **The Training Config**
|
||||
|
||||
Used to define the top-level settings of the PrimAITE environment, the reward values, and the session that is to be run.
|
||||
|
||||
* **The Lay Down Config**
|
||||
|
||||
Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERSs) and Access Control Rules.
|
||||
|
||||
Training Config:
|
||||
*******************
|
||||
|
||||
The Training Config file consists of the following attributes:
|
||||
|
||||
**Generic Config Values**
|
||||
|
||||
|
||||
* **agent_framework** [enum]
|
||||
|
||||
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
|
||||
|
||||
* NONE - Where a user developed agent is to be used
|
||||
* SB3 - Stable Baselines3
|
||||
* RLLIB - Ray RLlib.
|
||||
|
||||
* **agent_identifier**
|
||||
|
||||
This identifies the agent to use for the session. Select from one of the following:
|
||||
|
||||
* A2C - Advantage Actor Critic
|
||||
* PPO - Proximal Policy Optimization
|
||||
* HARDCODED - A custom built deterministic agent
|
||||
* RANDOM - A Stochastic random agent
|
||||
|
||||
|
||||
* **random_red_agent** [bool]
|
||||
|
||||
Determines if the session should be run with a random red agent
|
||||
|
||||
* **action_type** [enum]
|
||||
|
||||
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
|
||||
|
||||
|
||||
* **OBSERVATION_SPACE** [dict]
|
||||
|
||||
Allows for user to configure observation space by combining one or more observation components. List of available
|
||||
components is in :py:mod:`primaite.environment.observations`.
|
||||
|
||||
The observation space config item should have a ``components`` key which is a list of components. Each component
|
||||
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
|
||||
component while it is being initialised.
|
||||
|
||||
This example illustrates the correct format for the observation space config item
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
combine_service_traffic : False
|
||||
quantisation_levels: 99
|
||||
|
||||
|
||||
Currently available components are:
|
||||
|
||||
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
|
||||
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
|
||||
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
|
||||
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
|
||||
|
||||
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
|
||||
* ``quantisation_levels`` - how many discrete bandwidth usage levels to use for encoding. This can be an integer equal to or greater than 3.
|
||||
|
||||
The other configurable item is ``flatten`` which is false by default. When set to true, the observation space is flattened (turned into a 1-D vector). You should use this if your RL agent does not natively support observation space types like ``gym.Spaces.Tuple``.
|
||||
|
||||
* **num_train_episodes** [int]
|
||||
|
||||
This defines the number of episodes that the agent will train for.
|
||||
|
||||
|
||||
* **num_train_steps** [int]
|
||||
|
||||
Determines the number of steps to run in each episode of the training session.
|
||||
|
||||
|
||||
* **num_eval_episodes** [int]
|
||||
|
||||
This defines the number of episodes that the agent will be evaluated over.
|
||||
|
||||
|
||||
* **num_eval_steps** [int]
|
||||
|
||||
Determines the number of steps to run in each episode of the evaluation session.
|
||||
|
||||
|
||||
* **time_delay** [int]
|
||||
|
||||
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
|
||||
|
||||
|
||||
* **session_type** [text]
|
||||
|
||||
Type of session to be run (TRAINING, EVALUATION, or BOTH)
|
||||
|
||||
* **load_agent** [bool]
|
||||
|
||||
Determine whether to load an agent from file
|
||||
|
||||
* **agent_load_file** [text]
|
||||
|
||||
File path and file name of agent if you're loading one in
|
||||
|
||||
* **observation_space_high_value** [int]
|
||||
|
||||
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
|
||||
|
||||
* **implicit_acl_rule** [str]
|
||||
|
||||
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
|
||||
|
||||
* **max_number_acl_rules** [int]
|
||||
|
||||
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
|
||||
|
||||
**Reward-Based Config Values**
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
|
||||
* **Generic [all_ok]** [float]
|
||||
|
||||
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
|
||||
|
||||
* **Node Hardware State [off_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is off
|
||||
|
||||
* **Node Hardware State [off_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is off
|
||||
|
||||
* **Node Hardware State [on_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is on
|
||||
|
||||
* **Node Hardware State [on_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is on
|
||||
|
||||
* **Node Hardware State [resetting_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting]** [float]
|
||||
|
||||
The score to give when the node is resetting
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is good
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching]** [float]
|
||||
|
||||
The score to give when the state is patching
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised]** [float]
|
||||
|
||||
The score to give when the state is compromised
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed]** [float]
|
||||
|
||||
The score to give when the state is overwhelmed
|
||||
|
||||
* **Node File System State [good_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is good
|
||||
|
||||
* **Node File System State [good_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is good
|
||||
|
||||
* **Node File System State [good_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is good
|
||||
|
||||
* **Node File System State [good_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is good
|
||||
|
||||
* **Node File System State [repairing_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is repairing
|
||||
|
||||
* **Node File System State [repairing]** [float]
|
||||
|
||||
The score to give when the state is repairing
|
||||
|
||||
* **Node File System State [restoring_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is restoring
|
||||
|
||||
* **Node File System State [restoring]** [float]
|
||||
|
||||
The score to give when the state is restoring
|
||||
|
||||
* **Node File System State [corrupt_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt]** [float]
|
||||
|
||||
The score to give when the state is corrupt
|
||||
|
||||
* **Node File System State [destroyed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed]** [float]
|
||||
|
||||
The score to give when the state is destroyed
|
||||
|
||||
* **Node File System State [scanning]** [float]
|
||||
|
||||
The score to give when the state is scanning
|
||||
|
||||
* **IER Status [red_ier_running]** [float]
|
||||
|
||||
The score to give when a red agent IER is permitted to run
|
||||
|
||||
* **IER Status [green_ier_blocked]** [float]
|
||||
|
||||
The score to give when a green agent IER is prevented from running
|
||||
|
||||
**Patching / Reset Durations**
|
||||
|
||||
* **os_patching_duration** [int]
|
||||
|
||||
The number of steps to take when patching an Operating System
|
||||
|
||||
* **node_reset_duration** [int]
|
||||
|
||||
The number of steps to take when resetting a node's hardware state
|
||||
|
||||
* **service_patching_duration** [int]
|
||||
|
||||
The number of steps to take when patching a service
|
||||
|
||||
* **file_system_repairing_limit** [int]:
|
||||
|
||||
The number of steps to take when repairing the file system
|
||||
|
||||
* **file_system_restoring_limit** [int]
|
||||
|
||||
The number of steps to take when restoring the file system
|
||||
|
||||
* **file_system_scanning_limit** [int]
|
||||
|
||||
The number of steps to take when scanning the file system
|
||||
|
||||
* **deterministic** [bool]
|
||||
|
||||
Set to true if the agent evaluation should be deterministic. Default is ``False``
|
||||
|
||||
* **seed** [int]
|
||||
|
||||
Seed used in the randomisation in agent training. Default is ``None``
|
||||
|
||||
The Lay Down Config
|
||||
*******************
|
||||
|
||||
The lay down config file consists of the following attributes:
|
||||
|
||||
|
||||
* **itemType: STEPS** [int]
|
||||
|
||||
* **item_type: PORTS** [int]
|
||||
|
||||
Provides a list of ports modelled in this session
|
||||
|
||||
* **item_type: SERVICES** [freetext]
|
||||
|
||||
Provides a list of services modelled in this session
|
||||
|
||||
* **item_type: NODE**
|
||||
|
||||
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **name** [freetext]: Human-readable name of the component
|
||||
* **node_class** [enum]: Relates to the base type of the node. Can be SERVICE, ACTIVE or PASSIVE. PASSIVE nodes do not have an operating system or services. ACTIVE nodes have an operating system, but no services. SERVICE nodes have both an operating system and one or more services
|
||||
* **node_type** [enum]: Relates to the component type. Can be one of CCTV, SWITCH, COMPUTER, LINK, MONITOR, PRINTER, LOP, RTU, ACTUATOR or SERVER
|
||||
* **priority** [enum]: Provides a priority for each node. Can be one of P1, P2, P3, P4 or P5 (which P1 being the highest)
|
||||
* **hardware_state** [enum]: The initial hardware state of the node. Can be one of ON, OFF or RESETTING
|
||||
* **ip_address** [IP address]: The IP address of the component in format xxx.xxx.xxx.xxx
|
||||
* **software_state** [enum]: The intial state of the node operating system. Can be GOOD, PATCHING or COMPROMISED
|
||||
* **file_system_state** [enum]: The initial state of the node file system. Can be GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING
|
||||
* **services**: For each service associated with the node:
|
||||
|
||||
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
|
||||
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
|
||||
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **item_type: LINK**
|
||||
|
||||
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **name** [freetext]: Human-readable name of the component
|
||||
* **bandwidth** [int]: The bandwidth (in bits/s) of the link
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
|
||||
* **item_type: GREEN_IER**
|
||||
|
||||
Defines a green agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this IER to begin
|
||||
* **end_step** [int]: The end step (in the episode) for this IER to finish
|
||||
* **load** [int]: The load (in bits/s) for this IER to apply to links
|
||||
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
|
||||
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
* **mission_criticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
|
||||
|
||||
* **item_type: RED_IER**
|
||||
|
||||
Defines a red agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this IER to begin
|
||||
* **end_step** [int]: The end step (in the episode) for this IER to finish
|
||||
* **load** [int]: The load (in bits/s) for this IER to apply to links
|
||||
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
|
||||
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
* **mission_criticality** [enum]: Not currently used. Default to 0
|
||||
|
||||
* **item_type: GREEN_POL**
|
||||
|
||||
Defines a green agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this PoL to begin
|
||||
* **end_step** [int]: Not currently used. Default to same as start step
|
||||
* **nodeId** [int]: The ID of the node to apply the PoL to
|
||||
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
|
||||
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
|
||||
* **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
|
||||
|
||||
* **item_type: RED_POL**
|
||||
|
||||
Defines a red agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this PoL to begin
|
||||
* **end_step** [int]: Not currently used. Default to same as start step
|
||||
* **targetNodeId** [int]: The ID of the node to apply the PoL to
|
||||
* **initiator** [enum]: What initiates the PoL. Can be DIRECT, IER or SERVICE
|
||||
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
|
||||
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
|
||||
* **state** [enum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state) or GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING (for file system state)
|
||||
* **sourceNodeId** [int] The ID of the source node containing the service to check (used for SERVICE initiator)
|
||||
* **sourceNodeService** [freetext]: The service on the source node to check (used for SERVICE initiator). Must match a value in the services list for this node
|
||||
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **item_type: ACL_RULE**
|
||||
|
||||
Defines an initial Access Control List (ACL) rule. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **permission** [enum]: Defines either an allow or deny rule. Value must be either DENY or ALLOW
|
||||
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
|
||||
@@ -1,489 +1,13 @@
|
||||
Primaite v3 config
|
||||
******************
|
||||
|
||||
PrimAITE uses a single configuration file to define a cybersecurity scenario. This includes the computer network and multiple agents. There are three main sections: training_config, game, and simulation.
|
||||
|
||||
The simulation section describes the simulated network environment with which the agetns interact.
|
||||
|
||||
The game section describes the agents and their capabilities. Each agent has a unique type and is associated with a team (GREEN, RED, or BLUE). Each agent has a configurable observation space, action space, and reward function.
|
||||
|
||||
The training_config section describes the training parameters for the learning agents. This includes the number of episodes, the number of steps per episode, and the number of steps before the agents start learning. The training_config section also describes the learning algorithm used by the agents. The learning algorithm is specified by the name of the algorithm and the hyperparameters for the algorithm. The hyperparameters are specific to each algorithm and are described in the documentation for each algorithm.
|
||||
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _config:
|
||||
|
||||
The Config Files Explained
|
||||
==========================
|
||||
|
||||
PrimAITE uses two configuration files for its operation:
|
||||
|
||||
* **The Training Config**
|
||||
|
||||
Used to define the top-level settings of the PrimAITE environment, the reward values, and the session that is to be run.
|
||||
|
||||
* **The Lay Down Config**
|
||||
|
||||
Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERSs) and Access Control Rules.
|
||||
|
||||
Training Config:
|
||||
*******************
|
||||
|
||||
The Training Config file consists of the following attributes:
|
||||
|
||||
**Generic Config Values**
|
||||
|
||||
|
||||
* **agent_framework** [enum]
|
||||
|
||||
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
|
||||
|
||||
* NONE - Where a user developed agent is to be used
|
||||
* SB3 - Stable Baselines3
|
||||
* RLLIB - Ray RLlib.
|
||||
|
||||
* **agent_identifier**
|
||||
|
||||
This identifies the agent to use for the session. Select from one of the following:
|
||||
|
||||
* A2C - Advantage Actor Critic
|
||||
* PPO - Proximal Policy Optimization
|
||||
* HARDCODED - A custom built deterministic agent
|
||||
* RANDOM - A Stochastic random agent
|
||||
|
||||
|
||||
* **random_red_agent** [bool]
|
||||
|
||||
Determines if the session should be run with a random red agent
|
||||
|
||||
* **action_type** [enum]
|
||||
|
||||
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
|
||||
|
||||
|
||||
* **OBSERVATION_SPACE** [dict]
|
||||
|
||||
Allows for user to configure observation space by combining one or more observation components. List of available
|
||||
components is in :py:mod:`primaite.environment.observations`.
|
||||
|
||||
The observation space config item should have a ``components`` key which is a list of components. Each component
|
||||
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
|
||||
component while it is being initialised.
|
||||
|
||||
This example illustrates the correct format for the observation space config item
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
combine_service_traffic : False
|
||||
quantisation_levels: 99
|
||||
|
||||
|
||||
Currently available components are:
|
||||
|
||||
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
|
||||
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
|
||||
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
|
||||
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
|
||||
|
||||
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
|
||||
* ``quantisation_levels`` - how many discrete bandwidth usage levels to use for encoding. This can be an integer equal to or greater than 3.
|
||||
|
||||
The other configurable item is ``flatten`` which is false by default. When set to true, the observation space is flattened (turned into a 1-D vector). You should use this if your RL agent does not natively support observation space types like ``gym.Spaces.Tuple``.
|
||||
|
||||
* **num_train_episodes** [int]
|
||||
|
||||
This defines the number of episodes that the agent will train for.
|
||||
|
||||
|
||||
* **num_train_steps** [int]
|
||||
|
||||
Determines the number of steps to run in each episode of the training session.
|
||||
|
||||
|
||||
* **num_eval_episodes** [int]
|
||||
|
||||
This defines the number of episodes that the agent will be evaluated over.
|
||||
|
||||
|
||||
* **num_eval_steps** [int]
|
||||
|
||||
Determines the number of steps to run in each episode of the evaluation session.
|
||||
|
||||
|
||||
* **time_delay** [int]
|
||||
|
||||
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
|
||||
|
||||
|
||||
* **session_type** [text]
|
||||
|
||||
Type of session to be run (TRAINING, EVALUATION, or BOTH)
|
||||
|
||||
* **load_agent** [bool]
|
||||
|
||||
Determine whether to load an agent from file
|
||||
|
||||
* **agent_load_file** [text]
|
||||
|
||||
File path and file name of agent if you're loading one in
|
||||
|
||||
* **observation_space_high_value** [int]
|
||||
|
||||
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
|
||||
|
||||
* **implicit_acl_rule** [str]
|
||||
|
||||
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
|
||||
|
||||
* **max_number_acl_rules** [int]
|
||||
|
||||
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
|
||||
|
||||
**Reward-Based Config Values**
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
|
||||
* **Generic [all_ok]** [float]
|
||||
|
||||
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
|
||||
|
||||
* **Node Hardware State [off_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is off
|
||||
|
||||
* **Node Hardware State [off_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is off
|
||||
|
||||
* **Node Hardware State [on_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is on
|
||||
|
||||
* **Node Hardware State [on_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is on
|
||||
|
||||
* **Node Hardware State [resetting_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting]** [float]
|
||||
|
||||
The score to give when the node is resetting
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is good
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching]** [float]
|
||||
|
||||
The score to give when the state is patching
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised]** [float]
|
||||
|
||||
The score to give when the state is compromised
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed]** [float]
|
||||
|
||||
The score to give when the state is overwhelmed
|
||||
|
||||
* **Node File System State [good_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is good
|
||||
|
||||
* **Node File System State [good_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is good
|
||||
|
||||
* **Node File System State [good_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is good
|
||||
|
||||
* **Node File System State [good_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is good
|
||||
|
||||
* **Node File System State [repairing_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is repairing
|
||||
|
||||
* **Node File System State [repairing]** [float]
|
||||
|
||||
The score to give when the state is repairing
|
||||
|
||||
* **Node File System State [restoring_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is restoring
|
||||
|
||||
* **Node File System State [restoring]** [float]
|
||||
|
||||
The score to give when the state is restoring
|
||||
|
||||
* **Node File System State [corrupt_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt]** [float]
|
||||
|
||||
The score to give when the state is corrupt
|
||||
|
||||
* **Node File System State [destroyed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed]** [float]
|
||||
|
||||
The score to give when the state is destroyed
|
||||
|
||||
* **Node File System State [scanning]** [float]
|
||||
|
||||
The score to give when the state is scanning
|
||||
|
||||
* **IER Status [red_ier_running]** [float]
|
||||
|
||||
The score to give when a red agent IER is permitted to run
|
||||
|
||||
* **IER Status [green_ier_blocked]** [float]
|
||||
|
||||
The score to give when a green agent IER is prevented from running
|
||||
|
||||
**Patching / Reset Durations**
|
||||
|
||||
* **os_patching_duration** [int]
|
||||
|
||||
The number of steps to take when patching an Operating System
|
||||
|
||||
* **node_reset_duration** [int]
|
||||
|
||||
The number of steps to take when resetting a node's hardware state
|
||||
|
||||
* **service_patching_duration** [int]
|
||||
|
||||
The number of steps to take when patching a service
|
||||
|
||||
* **file_system_repairing_limit** [int]:
|
||||
|
||||
The number of steps to take when repairing the file system
|
||||
|
||||
* **file_system_restoring_limit** [int]
|
||||
|
||||
The number of steps to take when restoring the file system
|
||||
|
||||
* **file_system_scanning_limit** [int]
|
||||
|
||||
The number of steps to take when scanning the file system
|
||||
|
||||
* **deterministic** [bool]
|
||||
|
||||
Set to true if the agent evaluation should be deterministic. Default is ``False``
|
||||
|
||||
* **seed** [int]
|
||||
|
||||
Seed used in the randomisation in agent training. Default is ``None``
|
||||
|
||||
The Lay Down Config
|
||||
*******************
|
||||
|
||||
The lay down config file consists of the following attributes:
|
||||
|
||||
|
||||
* **itemType: STEPS** [int]
|
||||
|
||||
* **item_type: PORTS** [int]
|
||||
|
||||
Provides a list of ports modelled in this session
|
||||
|
||||
* **item_type: SERVICES** [freetext]
|
||||
|
||||
Provides a list of services modelled in this session
|
||||
|
||||
* **item_type: NODE**
|
||||
|
||||
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **name** [freetext]: Human-readable name of the component
|
||||
* **node_class** [enum]: Relates to the base type of the node. Can be SERVICE, ACTIVE or PASSIVE. PASSIVE nodes do not have an operating system or services. ACTIVE nodes have an operating system, but no services. SERVICE nodes have both an operating system and one or more services
|
||||
* **node_type** [enum]: Relates to the component type. Can be one of CCTV, SWITCH, COMPUTER, LINK, MONITOR, PRINTER, LOP, RTU, ACTUATOR or SERVER
|
||||
* **priority** [enum]: Provides a priority for each node. Can be one of P1, P2, P3, P4 or P5 (which P1 being the highest)
|
||||
* **hardware_state** [enum]: The initial hardware state of the node. Can be one of ON, OFF or RESETTING
|
||||
* **ip_address** [IP address]: The IP address of the component in format xxx.xxx.xxx.xxx
|
||||
* **software_state** [enum]: The intial state of the node operating system. Can be GOOD, PATCHING or COMPROMISED
|
||||
* **file_system_state** [enum]: The initial state of the node file system. Can be GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING
|
||||
* **services**: For each service associated with the node:
|
||||
|
||||
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
|
||||
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
|
||||
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **item_type: LINK**
|
||||
|
||||
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **name** [freetext]: Human-readable name of the component
|
||||
* **bandwidth** [int]: The bandwidth (in bits/s) of the link
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
|
||||
* **item_type: GREEN_IER**
|
||||
|
||||
Defines a green agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this IER to begin
|
||||
* **end_step** [int]: The end step (in the episode) for this IER to finish
|
||||
* **load** [int]: The load (in bits/s) for this IER to apply to links
|
||||
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
|
||||
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
* **mission_criticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
|
||||
|
||||
* **item_type: RED_IER**
|
||||
|
||||
Defines a red agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this IER to begin
|
||||
* **end_step** [int]: The end step (in the episode) for this IER to finish
|
||||
* **load** [int]: The load (in bits/s) for this IER to apply to links
|
||||
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
|
||||
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
* **mission_criticality** [enum]: Not currently used. Default to 0
|
||||
|
||||
* **item_type: GREEN_POL**
|
||||
|
||||
Defines a green agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this PoL to begin
|
||||
* **end_step** [int]: Not currently used. Default to same as start step
|
||||
* **nodeId** [int]: The ID of the node to apply the PoL to
|
||||
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
|
||||
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
|
||||
* **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
|
||||
|
||||
* **item_type: RED_POL**
|
||||
|
||||
Defines a red agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this PoL to begin
|
||||
* **end_step** [int]: Not currently used. Default to same as start step
|
||||
* **targetNodeId** [int]: The ID of the node to apply the PoL to
|
||||
* **initiator** [enum]: What initiates the PoL. Can be DIRECT, IER or SERVICE
|
||||
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
|
||||
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
|
||||
* **state** [enum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state) or GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING (for file system state)
|
||||
* **sourceNodeId** [int] The ID of the source node containing the service to check (used for SERVICE initiator)
|
||||
* **sourceNodeService** [freetext]: The service on the source node to check (used for SERVICE initiator). Must match a value in the services list for this node
|
||||
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **item_type: ACL_RULE**
|
||||
|
||||
Defines an initial Access Control List (ACL) rule. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **permission** [enum]: Defines either an allow or deny rule. Value must be either DENY or ALLOW
|
||||
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
|
||||
This needs a bit of refactoring so I haven't written extensive documentation about the config yet.
|
||||
|
||||
48
docs/source/game_layer.rst
Normal file
48
docs/source/game_layer.rst
Normal file
@@ -0,0 +1,48 @@
|
||||
PrimAITE Game layer
|
||||
*******************
|
||||
|
||||
The Primaite codebase consists of two main modules:
|
||||
|
||||
* ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes.
|
||||
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules, including ARCD GATE.
|
||||
|
||||
These two components have been decoupled to allow the agent training code in ARCD GATE to be reused with other simulators. The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. The game layer communicates with ARCD gate using the `Farama Gymnasium Spaces API <https://gymnasium.farama.org/api/spaces/>`_.
|
||||
|
||||
..
|
||||
TODO: write up these APIs and link them here.
|
||||
|
||||
|
||||
Game layer
|
||||
----------
|
||||
|
||||
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
|
||||
|
||||
PrimAITE Session
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. It also sends messages to ARCD GATE to perform reinforcement learning. ``PrimaiteSession`` keeps track of multiple agents of different types.
|
||||
|
||||
Agents
|
||||
^^^^^^
|
||||
|
||||
All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types:
|
||||
* RL agents action during each step is decided by an RL algorithm which lives inside of ARCD GATE. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
|
||||
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable.
|
||||
|
||||
..
|
||||
TODO: add seed to stochastic scripted agents
|
||||
|
||||
Observations
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class.
|
||||
|
||||
Actions
|
||||
^^^^^^^
|
||||
|
||||
An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class.
|
||||
|
||||
Rewards
|
||||
^^^^^^^
|
||||
|
||||
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class.
|
||||
726
example_config.yaml
Normal file
726
example_config.yaml
Normal file
@@ -0,0 +1,726 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 20
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 20
|
||||
n_eval_steps: 128
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: GATERLAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
22
sandbox.py
Normal file
22
sandbox.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# flake8: noqa
|
||||
import logging
|
||||
|
||||
from primaite import _PRIMAITE_CONFIG, PRIMAITE_PATHS
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
_PRIMAITE_CONFIG["log_level"] = logging.DEBUG
|
||||
print(PRIMAITE_PATHS.app_log_dir_path)
|
||||
import itertools
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
with open("example_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
sess = PrimaiteSession.from_config(cfg)
|
||||
|
||||
sess.start_session()
|
||||
@@ -6,7 +6,7 @@ from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
|
||||
1
src/primaite/game/__init__.py
Normal file
1
src/primaite/game/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""PrimAITE Game Layer."""
|
||||
31
src/primaite/game/agent/GATE_agents.py
Normal file
31
src/primaite/game/agent/GATE_agents.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# flake8: noqa
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GATERLAgent(AbstractGATEAgent):
|
||||
...
|
||||
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
|
||||
# because when we are supporting MARL, the actions form multiple agents will have to be batched
|
||||
|
||||
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
|
||||
# with the actions for those agents.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str | None,
|
||||
action_space: ActionManager | None,
|
||||
observation_space: ObservationSpace | None,
|
||||
reward_function: RewardFunction | None,
|
||||
) -> None:
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
return self.most_recent_action
|
||||
0
src/primaite/game/agent/__init__.py
Normal file
0
src/primaite/game/agent/__init__.py
Normal file
866
src/primaite/game/agent/actions.py
Normal file
866
src/primaite/game/agent/actions.py
Normal file
@@ -0,0 +1,866 @@
|
||||
"""
|
||||
This module contains the ActionManager class which belongs to the Agent class.
|
||||
|
||||
An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of
|
||||
AbstractAction. The ActionManager is responsible for:
|
||||
1. Creating the action space from a list of action types.
|
||||
2. Converting an integer action choice into a specific action and parameter choice.
|
||||
3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This
|
||||
ensures that requests conform to the simulator's request format.
|
||||
"""
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
"""Base class for actions."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
"""
|
||||
Init method for action.
|
||||
|
||||
All action init functions should accept **kwargs as a way of ignoring extra arguments.
|
||||
|
||||
Since many parameters are defined for the action space as a whole (such as max files per folder, max services
|
||||
per node), we need to pass those options to every action that gets created. To pervent verbosity, these
|
||||
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
|
||||
"""
|
||||
self.name: str = ""
|
||||
"""Human-readable action identifier used for printing, logging, and reporting."""
|
||||
self.shape: Dict[str, int] = {}
|
||||
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
|
||||
align with the keyword args of the form_request method."""
|
||||
self.manager: ActionManager = manager
|
||||
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return []
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction):
|
||||
"""Action which does nothing. This is here to allow agents to be idle if they choose to."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.name = "DONOTHING"
|
||||
self.shape: Dict[str, int] = {
|
||||
"dummy": 1,
|
||||
}
|
||||
# This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1),
|
||||
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self, **kwargs) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for service actions.
|
||||
|
||||
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, service_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
|
||||
if node_uuid is None or service_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_uuid, "services", service_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
"""Action which scans a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction):
|
||||
"""Action which stops a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
"""Action which starts a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
"""Action which pauses a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
"""Action which resumes a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
"""Action which restarts a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
"""Action which disables a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
"""Action which enables a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
|
||||
Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
if node_uuid is None or folder_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
"""Action which scans a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
|
||||
"""Action which checks the hash of a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction):
|
||||
"""Action which repairs a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction):
|
||||
"""Action which restores a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "restore"
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction):
|
||||
"""Abstract base class for file actions.
|
||||
|
||||
Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit
|
||||
from this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
if node_uuid is None or folder_uuid is None or file_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, "files", file_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction):
|
||||
"""Action which scans a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction):
|
||||
"""Action which checks the hash of a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "checkhash"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
"""Action which deletes a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "delete"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
"""Action which repairs a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "repair"
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction):
|
||||
"""Action which restores a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "restore"
|
||||
|
||||
|
||||
class NodeFileCorruptAction(NodeFileAbstractAction):
|
||||
"""Action which corrupts a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "corrupt"
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for node actions.
|
||||
|
||||
Any action which applies to a node and uses node_id as its only parameter can inherit from this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
return ["network", "node", node_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction):
|
||||
"""Action which scans a node's OS."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction):
|
||||
"""Action which shuts down a node."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction):
|
||||
"""Action which starts up a node."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "startup"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction):
|
||||
"""Action which resets a node."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "reset"
|
||||
|
||||
|
||||
class NetworkACLAddRuleAction(AbstractAction):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: "ActionManager",
|
||||
target_router_uuid: str,
|
||||
max_acl_rules: int,
|
||||
num_ips: int,
|
||||
num_ports: int,
|
||||
num_protocols: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for NetworkACLAddRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param target_router_uuid: UUID of the router to which the ACL rule should be added.
|
||||
:type target_router_uuid: str
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
:param num_ips: Number of IP addresses in the simulation.
|
||||
:type num_ips: int
|
||||
:param num_ports: Number of ports in the simulation.
|
||||
:type num_ports: int
|
||||
:param num_protocols: Number of protocols in the simulation.
|
||||
:type num_protocols: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 3
|
||||
self.shape: Dict[str, int] = {
|
||||
"position": max_acl_rules,
|
||||
"permission": num_permissions,
|
||||
"source_ip_id": num_ips,
|
||||
"dest_ip_id": num_ips,
|
||||
"source_port_id": num_ports,
|
||||
"dest_port_id": num_ports,
|
||||
"protocol_id": num_protocols,
|
||||
}
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(
|
||||
self,
|
||||
position: int,
|
||||
permission: int,
|
||||
source_ip_id: int,
|
||||
dest_ip_id: int,
|
||||
source_port_id: int,
|
||||
dest_port_id: int,
|
||||
protocol_id: int,
|
||||
) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if permission == 0:
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
elif permission == 1:
|
||||
permission_str = "ALLOW"
|
||||
elif permission == 2:
|
||||
permission_str = "DENY"
|
||||
else:
|
||||
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
|
||||
if protocol_id == 0:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
|
||||
if protocol_id == 1:
|
||||
protocol = "ALL"
|
||||
else:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if source_ip_id in [0, 1]:
|
||||
src_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_ip_id in (0, 1):
|
||||
dst_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.target_router_uuid,
|
||||
"acl",
|
||||
"add_rule",
|
||||
permission_str,
|
||||
protocol,
|
||||
src_ip,
|
||||
src_port,
|
||||
dst_ip,
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
|
||||
|
||||
class NetworkACLRemoveRuleAction(AbstractAction):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
|
||||
"""Init method for NetworkACLRemoveRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param target_router_uuid: UUID of the router from which the ACL rule should be removed.
|
||||
:type target_router_uuid: str
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(self, position: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
|
||||
|
||||
|
||||
class NetworkNICAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for NIC actions.
|
||||
|
||||
Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base
|
||||
class.
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkNICAbstractAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param num_nodes: Number of nodes in the simulation.
|
||||
:type num_nodes: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id)
|
||||
nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
|
||||
if node_uuid is None or nic_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
node_uuid,
|
||||
"nic",
|
||||
nic_uuid,
|
||||
self.verb,
|
||||
]
|
||||
|
||||
|
||||
class NetworkNICEnableAction(NetworkNICAbstractAction):
|
||||
"""Action which enables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
__act_class_identifiers: Dict[str, type] = {
|
||||
"DONOTHING": DoNothingAction,
|
||||
"NODE_SERVICE_SCAN": NodeServiceScanAction,
|
||||
"NODE_SERVICE_STOP": NodeServiceStopAction,
|
||||
"NODE_SERVICE_START": NodeServiceStartAction,
|
||||
"NODE_SERVICE_PAUSE": NodeServicePauseAction,
|
||||
"NODE_SERVICE_RESUME": NodeServiceResumeAction,
|
||||
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
"NODE_FILE_SCAN": NodeFileScanAction,
|
||||
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
"NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
"NODE_FILE_REPAIR": NodeFileRepairAction,
|
||||
"NODE_FILE_RESTORE": NodeFileRestoreAction,
|
||||
"NODE_FILE_CORRUPT": NodeFileCorruptAction,
|
||||
"NODE_FOLDER_SCAN": NodeFolderScanAction,
|
||||
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
|
||||
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
|
||||
"NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
|
||||
"NODE_OS_SCAN": NodeOSScanAction,
|
||||
"NODE_SHUTDOWN": NodeShutdownAction,
|
||||
"NODE_STARTUP": NodeStartupAction,
|
||||
"NODE_RESET": NodeResetAction,
|
||||
"NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
|
||||
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: "PrimaiteSession", # reference to session for looking up stuff
|
||||
actions: List[str], # stores list of actions available to agent
|
||||
node_uuids: List[str], # allows mapping index to node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
max_files_per_folder: int = 2, # allows calculating shape
|
||||
max_services_per_node: int = 2, # allows calculating shape
|
||||
max_nics_per_node: int = 8, # allows calculating shape
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
|
||||
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param session: Reference to the session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param actions: List of action types which should be made available to the agent.
|
||||
:type actions: List[str]
|
||||
:param node_uuids: List of node UUIDs that this agent can act on.
|
||||
:type node_uuids: List[str]
|
||||
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
|
||||
:type max_folders_per_node: int
|
||||
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
|
||||
:type max_files_per_folder: int
|
||||
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
|
||||
:type max_services_per_node: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
|
||||
:type max_nics_per_node: int
|
||||
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
|
||||
:type max_acl_rules: int
|
||||
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
|
||||
:type protocols: List[str]
|
||||
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
|
||||
:type ports: List[str]
|
||||
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_address_list: Optional[List[str]]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
self.ip_address_list: List[str]
|
||||
if ip_address_list is not None:
|
||||
self.ip_address_list = ip_address_list
|
||||
else:
|
||||
self.ip_address_list = []
|
||||
for node_uuid in self.node_uuids:
|
||||
node_obj = self.sim.network.nodes[node_uuid]
|
||||
nics = node_obj.nics
|
||||
for nic_uuid, nic_obj in nics.items():
|
||||
self.ip_address_list.append(nic_obj.ip_address)
|
||||
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
"num_nodes": len(node_uuids),
|
||||
"num_folders": max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
"num_services": max_services_per_node,
|
||||
"num_nics": max_nics_per_node,
|
||||
"num_acl_rules": max_acl_rules,
|
||||
"num_protocols": len(self.protocols),
|
||||
"num_ports": len(self.protocols),
|
||||
"num_ips": len(self.ip_address_list),
|
||||
"max_acl_rules": max_acl_rules,
|
||||
"max_nics_per_node": max_nics_per_node,
|
||||
}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_spec in actions:
|
||||
# each action is provided into the action space config like this:
|
||||
# - type: ACTION_TYPE
|
||||
# options:
|
||||
# option_1: value1
|
||||
# option_2: value2
|
||||
# where `type` decides which AbstractAction subclass should be used
|
||||
# and `options` is an optional dict of options to pass to the init method of the action class
|
||||
act_type = act_spec.get("type")
|
||||
act_options = act_spec.get("options", {})
|
||||
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
Action mapping that converts an integer to a specific action and parameter choice.
|
||||
|
||||
For example :
|
||||
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
|
||||
"""
|
||||
if act_map is None:
|
||||
self.action_map = self._enumerate_actions()
|
||||
else:
|
||||
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def _enumerate_actions(
|
||||
self,
|
||||
) -> Dict[int, Tuple[str, Dict]]:
|
||||
"""Generate a list of all the possible actions that could be taken.
|
||||
|
||||
This enumerates all actions all combinations of parametes you could choose for those actions. The output
|
||||
of this function is intended to populate the self.action_map parameter in the situation where the user provides
|
||||
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
|
||||
|
||||
The enumeration relies on the Actions' `shape` attribute.
|
||||
|
||||
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
|
||||
An example output could be:
|
||||
{0: ("DONOTHING", {'dummy': 0}),
|
||||
1: ("NODE_OS_SCAN", {'node_id': 0}),
|
||||
2: ("NODE_OS_SCAN", {'node_id': 1}),
|
||||
3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}),
|
||||
... #etc...
|
||||
}
|
||||
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
|
||||
"""
|
||||
all_action_possibilities = []
|
||||
for act_name, action in self.actions.items():
|
||||
param_names = list(action.shape.keys())
|
||||
num_possibilities = list(action.shape.values())
|
||||
possibilities = [range(n) for n in num_possibilities]
|
||||
|
||||
param_combinations = list(itertools.product(*possibilities))
|
||||
all_action_possibilities.extend(
|
||||
[
|
||||
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
|
||||
for j in range(len(param_combinations))
|
||||
]
|
||||
)
|
||||
|
||||
return {i: p for i, p in enumerate(all_action_possibilities)}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format."""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
"""The CAOS format is basically a action identifier, followed by parameters stored in a dictionary"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_obj = self.actions[action_identifier]
|
||||
return act_obj.form_request(**action_options)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
def get_node_uuid_by_idx(self, node_idx: int) -> str:
|
||||
"""
|
||||
Get the node UUID corresponding to the given index.
|
||||
|
||||
:param node_idx: The index of the node to retrieve.
|
||||
:type node_idx: int
|
||||
:return: The node UUID.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.node_uuids[node_idx]
|
||||
|
||||
def get_folder_uuid_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
|
||||
"""
|
||||
Get the folder UUID corresponding to the given node and folder indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:return: The UUID of the folder. Or None if the node has fewer folders than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
|
||||
|
||||
def get_file_uuid_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
|
||||
"""Get the file UUID corresponding to the given node, folder, and file indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:param file_idx: The index of the file in the folder.
|
||||
:type file_idx: int
|
||||
:return: The UUID of the file. Or None if the node has fewer folders than the given index, or the folder has
|
||||
fewer files than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
if len(folder_uuids) <= folder_idx:
|
||||
return None
|
||||
folder = node.file_system.folders[folder_uuids[folder_idx]]
|
||||
file_uuids = list(folder.files.keys())
|
||||
return file_uuids[file_idx] if len(file_uuids) > file_idx else None
|
||||
|
||||
def get_service_uuid_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
|
||||
"""Get the service UUID corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param service_idx: The index of the service on the node.
|
||||
:type service_idx: int
|
||||
:return: The UUID of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
service_uuids = list(node.services.keys())
|
||||
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
"""Get the internet protocol corresponding to the given index.
|
||||
|
||||
:param protocol_idx: The index of the protocol to retrieve.
|
||||
:type protocol_idx: int
|
||||
:return: The protocol.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
"""
|
||||
Get the IP address corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP address to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The IP address.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
"""
|
||||
Get the port corresponding to the given index.
|
||||
|
||||
:param port_idx: The index of the port to retrieve.
|
||||
:type port_idx: int
|
||||
:return: The port.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str:
|
||||
"""
|
||||
Get the NIC UUID corresponding to the given node and NIC indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param nic_idx: The index of the NIC on the node.
|
||||
:type nic_idx: int
|
||||
:return: The NIC UUID.
|
||||
:rtype: str
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node_obj = self.sim.network.nodes[node_uuid]
|
||||
nics = list(node_obj.nics.keys())
|
||||
if len(nics) <= nic_idx:
|
||||
return None
|
||||
return nics[nic_idx]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
The action space config supports the following three sections:
|
||||
1. ``action_list``
|
||||
``action_list`` contians a list action components which need to be included in the action space.
|
||||
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
|
||||
which will be passed to the action class's __init__ method during initialisation.
|
||||
2. ``action_map``
|
||||
Since the agent uses a discrete action space which acts as a flattened version of the component-based
|
||||
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
|
||||
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
|
||||
correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be "
|
||||
3. ``options``
|
||||
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param session: The Primaite Session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
obj = cls(
|
||||
session=session,
|
||||
actions=cfg["action_list"],
|
||||
# node_uuids=cfg["options"]["node_uuids"],
|
||||
**cfg["options"],
|
||||
protocols=session.options.protocols,
|
||||
ports=session.options.ports,
|
||||
ip_address_list=None,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
return obj
|
||||
116
src/primaite/game/agent/interface.py
Normal file
116
src/primaite/game/agent/interface.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
|
||||
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
:type agent_name: Optional[str]
|
||||
:param action_space: Action space for the agent.
|
||||
:type action_space: Optional[ActionManager]
|
||||
:param observation_space: Observation space for the agent.
|
||||
:type observation_space: Optional[ObservationSpace]
|
||||
:param reward_function: Reward function for the agent.
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
|
||||
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = None
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
|
||||
def calculate_reward_from_state(self, state: Dict) -> float:
|
||||
"""
|
||||
Use the reward function to calculate a reward from the state.
|
||||
|
||||
:param state: State of the environment.
|
||||
:type state: Dict
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.calculate(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
Subclasses should implement agent logic here. It should use the observation as input to decide best next action.
|
||||
|
||||
:param obs: Observation of the environment.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
|
||||
:type reward: float, optional
|
||||
:return: Action to be taken in the environment.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.action_space.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
|
||||
:param obs: _description_
|
||||
:type obs: ObsType
|
||||
:param reward: _description_, defaults to None
|
||||
:type reward: float, optional
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
"""Base class for actors controlled via external messages, such as RL policies."""
|
||||
|
||||
...
|
||||
984
src/primaite/game/agent/observations.py
Normal file
984
src/primaite/game/agent/observations.py
Normal file
@@ -0,0 +1,984 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
"""Abstract class for an observation space component."""
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Return an observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
|
||||
a subclass of this class may need to translate from a 'reference' to a UUID.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
"""Observation of a file on a node in the network."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""
|
||||
Initialise file observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["health_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param session: _description_
|
||||
:type session: PrimaiteSession
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config["file_name"]])
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
"""Observation of a service in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise service observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": service_state["operating_state"], "health_status": service_state["health_state"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 10) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": utilisation_category}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
"""Folder observation, including files inside of the folder."""
|
||||
|
||||
def __init__(
|
||||
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
|
||||
) -> None:
|
||||
"""Initialise folder Observation, including files inside of the folder.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
:type max_files: int, optional
|
||||
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
|
||||
that even if new files are created, the existing files will always occupy the same space in the observation
|
||||
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
|
||||
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
|
||||
name, it will take the position defined in this dict. Defaults to {}
|
||||
:type file_positions: Dict[int, str], optional
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files_per_folder:
|
||||
self.files.append(FileObservation())
|
||||
while len(self.files) > num_files_per_folder:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folde observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warn(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_uuid>,'file_system']
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed folder observation
|
||||
:rtype: FolderObservation
|
||||
"""
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"nic_status": 0}
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
return {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
|
||||
) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
"""Observation of a node in the network. Includes services, folders and NICs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
nics: List[NicObservation] = [],
|
||||
logon_status: bool = False,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service UUID, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node
|
||||
, defaults to 2
|
||||
:type max_services: int, optional
|
||||
:param folders: Mapping between position in observation space and folder name, defaults to {}
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
|
||||
:type nics: Dict[int,str], optional
|
||||
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services_per_node:
|
||||
# add empty service observation without `where` parameter so it always returns default (blank) observation
|
||||
self.services.append(ServiceObservation())
|
||||
while len(self.services) > num_services_per_node:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warn(msg)
|
||||
# truncate service list
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
# add empty folder observation without `where` parameter that will always return default (blank) observations
|
||||
while len(self.folders) < num_folders_per_node:
|
||||
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warn(msg)
|
||||
|
||||
self.nics: List[NicObservation] = nics
|
||||
while len(self.nics) < num_nics_per_node:
|
||||
self.nics.append(NicObservation())
|
||||
while len(self.nics) > num_nics_per_node:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warn(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation["logon_status"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
|
||||
if self.logon_status:
|
||||
obs["logon_status"] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
space_shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape["logon_status"] = spaces.Discrete(3)
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
session: "PrimaiteSession",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> "NodeObservation":
|
||||
"""Create node observation from a config. Also creates child service, folder and NIC observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_services_per_node: int, optional
|
||||
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_folders_per_node: int, optional
|
||||
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
"""
|
||||
node_uuid = session.ref_map_nodes[config["node_ref"]]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_uuid]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_uuid]
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
|
||||
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
folders=folders,
|
||||
nics=nics,
|
||||
logon_status=logon_status,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: list[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_uuid>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
if rule_state is None:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
"source_port": self.port_to_id[rule_state["src_port"]],
|
||||
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
|
||||
"dest_port": self.port_to_id[rule_state["dst_port"]],
|
||||
"protocol": self.protocol_to_id[rule_state["protocol"]],
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_ref"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.ethernet_port[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=session.options.ports,
|
||||
protocols=session.options.protocols,
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
"""Container for all observations used by the blue agent in UC2.
|
||||
|
||||
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
|
||||
for the purpose of compiling several observation components.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialise UC2 blue observation.
|
||||
|
||||
:param nodes: List of node observations
|
||||
:type nodes: List[NodeObservation]
|
||||
:param links: List of link observations
|
||||
:type links: List[LinkObservation]
|
||||
:param acl: The Access Control List observation
|
||||
:type acl: AclObservation
|
||||
:param ics: The ICS observation
|
||||
:type ics: ICSObservation
|
||||
:param where: Where in the simulation state dict to find information. Not used in this particular observation
|
||||
because it only compiles other observations and doesn't contribute any new information, defaults to None
|
||||
:type where: Optional[List[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
obs["ICS"] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
num_nics_per_node = config["num_nics_per_node"]
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
session=session,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for n in node_configs
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, session=session)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, session=session)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
"""Container for all observations used by the red agent in UC2."""
|
||||
|
||||
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation):
|
||||
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ObservationSpace:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
The observation space has the purpose of:
|
||||
1. Reading the outputted state from the PrimAITE Simulation.
|
||||
2. Selecting parts of the simulation state that are requested by the simulation config
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
|
||||
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
|
||||
# refactor.
|
||||
|
||||
def __init__(self, observation: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
"""
|
||||
return self.obs.observe(state)
|
||||
|
||||
@property
|
||||
def space(self) -> None:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
It should contain the key 'type' which selects which observation class to use (from a choice of:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimaiteSession object that spawned this observation.
|
||||
:type session: PrimaiteSession
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
284
src/primaite/game/agent/rewards.py
Normal file
284
src/primaite/game/agent/rewards.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Manages the reward function for the agent.
|
||||
|
||||
Each agent is equipped with a RewardFunction, which is made up of a list of reward components. The components are
|
||||
designed to calculate a reward value based on the current state of the simulation. The overall reward function is a
|
||||
weighed sum of the components.
|
||||
|
||||
The reward function is typically specified using a config yaml file or a config dictionary. The following example shows
|
||||
the structure:
|
||||
```yaml
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_database_client
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:return: The reward component.
|
||||
:rtype: AbstractReward
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor. Should be empty.
|
||||
:type config: dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_uuid: UUID of the node which contains the database file.
|
||||
:type node_uuid: str
|
||||
:param folder_name: folder which contains the database file.
|
||||
:type folder_name: str
|
||||
:param file_name: name of the database file.
|
||||
:type file_name: str
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
node_uuid,
|
||||
"file_system",
|
||||
"folders",
|
||||
folder_name,
|
||||
"files",
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
health_status = database_file_state["health_status"]
|
||||
if health_status == "corrupted":
|
||||
return -1
|
||||
elif health_status == "good":
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:return: The reward component.
|
||||
:rtype: DatabaseFileIntegrity
|
||||
"""
|
||||
node_ref = config.get("node_ref")
|
||||
folder_name = config.get("folder_name")
|
||||
file_name = config.get("file_name")
|
||||
if not node_ref:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not folder_name:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not file_name:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
if not node_uuid:
|
||||
_LOGGER.error(
|
||||
(
|
||||
f"{cls.__name__} could not be initialised from config because the referenced node could not be "
|
||||
f"found in the simulation"
|
||||
)
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
|
||||
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
def __init__(self, node_uuid: str, service_uuid: str) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_uuid: UUID of the node which contains the web server service.
|
||||
:type node_uuid: str
|
||||
:param service_uuid: UUID of the web server service.
|
||||
:type service_uuid: str
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
"""
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if web_service_state is NOT_PRESENT_IN_STATE:
|
||||
print("error getting web service state")
|
||||
return 0.0
|
||||
most_recent_return_code = web_service_state["last_response_status_code"]
|
||||
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
|
||||
if most_recent_return_code == 200:
|
||||
return 1.0
|
||||
elif most_recent_return_code == 404:
|
||||
return -1.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:return: The reward component.
|
||||
:rtype: WebServer404Penalty
|
||||
"""
|
||||
node_ref = config.get("node_ref")
|
||||
service_ref = config.get("service_ref")
|
||||
if not (node_ref and service_ref):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
|
||||
"found in reward config."
|
||||
)
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
service_uuid = session.ref_map_services[service_ref].uuid
|
||||
if not (node_uuid and service_uuid):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
|
||||
" found in the simulator."
|
||||
)
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() # TODO: consider erroring here as well
|
||||
|
||||
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
|
||||
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
|
||||
:param component: Instance of a reward component.
|
||||
:type component: AbstractReward
|
||||
:param weight: Relative weight of the reward component, defaults to 1.0
|
||||
:type weight: float, optional
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
"""
|
||||
total = 0.0
|
||||
for comp_and_weight in self.reward_components:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
return total
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
|
||||
"""Create a reward function from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward manager's constructor
|
||||
:type config: Dict
|
||||
:param session: Reference to the PrimAITE Session object
|
||||
:type session: PrimaiteSession
|
||||
:return: The reward manager.
|
||||
:rtype: RewardFunction
|
||||
"""
|
||||
new = cls()
|
||||
|
||||
for rew_component_cfg in config["reward_components"]:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
weight = rew_component_cfg.get("weight", 1.0)
|
||||
rew_class = cls.__rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
|
||||
new.regsiter_component(component=rew_instance, weight=weight)
|
||||
return new
|
||||
14
src/primaite/game/agent/scripted_agents.py
Normal file
14
src/primaite/game/agent/scripted_agents.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Agents with predefined behaviours."""
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class GreenWebBrowsingAgent(AbstractScriptedAgent):
|
||||
"""Scripted agent which attempts to send web requests to a target node."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RedDatabaseCorruptingAgent(AbstractScriptedAgent):
|
||||
"""Scripted agent which attempts to corrupt the database of the target node."""
|
||||
|
||||
raise NotImplementedError
|
||||
30
src/primaite/game/agent/utils.py
Normal file
30
src/primaite/game/agent/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Any, Dict, Hashable, Sequence
|
||||
|
||||
NOT_PRESENT_IN_STATE = object()
|
||||
"""
|
||||
Need an object to return when the sim state does not contain a requested value. Cannot use None because sometimes
|
||||
the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose.
|
||||
"""
|
||||
|
||||
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
"""
|
||||
Access an item from a deeply dictionary with a list of keys.
|
||||
|
||||
For example, if the dictionary is {1: 'a', 2: {3: {4: 'b'}}}, then the key [2, 3, 4] would return 'b', and the key
|
||||
[2, 3] would return {4: 'b'}. Raises a KeyError if specified key does not exist at any level of nesting.
|
||||
|
||||
:param dictionary: Deeply nested dictionary
|
||||
:type dictionary: Dict
|
||||
:param keys: List of dict keys used to traverse the nested dict. Each item corresponds to one level of depth.
|
||||
:type keys: List[Hashable]
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
if len(key_list) == 0:
|
||||
return dictionary
|
||||
k = key_list.pop(0)
|
||||
if k not in dictionary:
|
||||
return NOT_PRESENT_IN_STATE
|
||||
return access_from_nested_dict(dictionary[k], key_list)
|
||||
471
src/primaite/game/session.py
Normal file
471
src/primaite/game/session.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from arcd_gate.client.gate_client import ActType, GATEClient
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces.utils import flatten, flatten_space
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
|
||||
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
|
||||
|
||||
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
|
||||
"""
|
||||
Create a new GATE client for PrimAITE.
|
||||
|
||||
:param parent_session: The parent session object.
|
||||
:type parent_session: PrimaiteSession
|
||||
:param service_port: The port on which the GATE service is running.
|
||||
:type service_port: int, optional
|
||||
"""
|
||||
super().__init__(service_port=service_port)
|
||||
self.parent_session: "PrimaiteSession" = parent_session
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
"""The reinforcement learning framework to use."""
|
||||
return self.parent_session.training_options.rl_framework
|
||||
|
||||
@property
|
||||
def rl_algorithm(self) -> str:
|
||||
"""The reinforcement learning algorithm to use."""
|
||||
return self.parent_session.training_options.rl_algorithm
|
||||
|
||||
@property
|
||||
def seed(self) -> int | None:
|
||||
"""The seed to use for the environment's random number generator."""
|
||||
return self.parent_session.training_options.seed
|
||||
|
||||
@property
|
||||
def n_learn_episodes(self) -> int:
|
||||
"""The number of episodes in each learning run."""
|
||||
return self.parent_session.training_options.n_learn_episodes
|
||||
|
||||
@property
|
||||
def n_learn_steps(self) -> int:
|
||||
"""The number of steps in each learning episode."""
|
||||
return self.parent_session.training_options.n_learn_steps
|
||||
|
||||
@property
|
||||
def n_eval_episodes(self) -> int:
|
||||
"""The number of episodes in each evaluation run."""
|
||||
return self.parent_session.training_options.n_eval_episodes
|
||||
|
||||
@property
|
||||
def n_eval_steps(self) -> int:
|
||||
"""The number of steps in each evaluation episode."""
|
||||
return self.parent_session.training_options.n_eval_steps
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
"""The gym action space of the agent."""
|
||||
return self.parent_session.rl_agent.action_space.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
"""The gymnasium observation space of the agent."""
|
||||
return flatten_space(self.parent_session.rl_agent.observation_space.space)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
|
||||
"""Take a step in the environment.
|
||||
|
||||
This method is called by GATE to advance the simulation by one timestep.
|
||||
|
||||
:param action: The agent's action.
|
||||
:type action: ActType
|
||||
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
|
||||
:rtype: Tuple[ObsType, float, bool, bool, Dict]
|
||||
"""
|
||||
self.parent_session.rl_agent.most_recent_action = action
|
||||
self.parent_session.step()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
rew = self.parent_session.rl_agent.reward_function.calculate(state)
|
||||
term = False
|
||||
trunc = False
|
||||
info = {}
|
||||
return obs, rew, term, trunc, info
|
||||
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment.
|
||||
|
||||
This method is called when the environment is initialized and at the end of each episode.
|
||||
|
||||
:param seed: The seed to use for the environment's random number generator.
|
||||
:type seed: int, optional
|
||||
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
|
||||
compatibility with GATE.
|
||||
:type options: dict[str, Any], optional
|
||||
:return: The initial observation and an empty info dictionary.
|
||||
:rtype: Tuple[ObsType, Dict]
|
||||
"""
|
||||
self.parent_session.reset()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
return obs, {}
|
||||
|
||||
def close(self):
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
self.parent_session.close()
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
"""
|
||||
Global options which are applicable to all of the agents in the game.
|
||||
|
||||
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
|
||||
"""
|
||||
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
rl_framework: str
|
||||
rl_algorithm: str
|
||||
seed: Optional[int]
|
||||
n_learn_episodes: int
|
||||
n_learn_steps: int
|
||||
n_eval_episodes: int
|
||||
n_eval_steps: int
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
|
||||
|
||||
def __init__(self):
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
self.rl_agent: AbstractAgent
|
||||
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.options: PrimaiteSessionOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
self.ref_map_services: Dict[str, Service] = {}
|
||||
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
|
||||
self.ref_map_applications: Dict[str, Application] = {}
|
||||
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
|
||||
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
self.gate_client.start()
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
|
||||
This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
|
||||
agent. The steps are as follows:
|
||||
1. The simulation state is updated.
|
||||
2. The simulation state is sent to each agent.
|
||||
3. Each agent converts the state to an observation and calculates a reward.
|
||||
4. Each agent chooses an action based on the observation.
|
||||
5. Each agent converts the action to a request.
|
||||
6. The simulation applies the requests.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
for agent in self.agents:
|
||||
# 3. primaite session asks simulation to provide initial state
|
||||
# 4. primate session gives state to all agents
|
||||
# 5. primaite session asks agents to produce an action based on most recent state
|
||||
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
|
||||
sim_state = self.simulation.describe_state()
|
||||
|
||||
# 6. each agent takes most recent state and converts it to CAOS observation
|
||||
agent_obs = agent.convert_state_to_obs(sim_state)
|
||||
|
||||
# 7. meanwhile each agent also takes state and calculates reward
|
||||
agent_reward = agent.calculate_reward_from_state(sim_state)
|
||||
|
||||
# 8. each agent takes observation and applies decision rule to observation to create CAOS
|
||||
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
|
||||
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
|
||||
# code should live inside of the GATE agent subclass)
|
||||
# gets action in CAOS format
|
||||
_LOGGER.debug("Getting agent action")
|
||||
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
|
||||
agent_request = agent.format_request(agent_action, action_options)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
|
||||
self.simulation.apply_request(agent_request)
|
||||
|
||||
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
self.step_counter += 1
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent. Used by GATE.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
The specification for each of the three major areas is described in a separate documentation page.
|
||||
# TODO: create documentation page and add links to it here.
|
||||
|
||||
:param cfg: The config dictionary.
|
||||
:type cfg: dict
|
||||
:return: A PrimaiteSession object.
|
||||
:rtype: PrimaiteSession
|
||||
"""
|
||||
sess = cls()
|
||||
sess.options = PrimaiteSessionOptions(
|
||||
ports=cfg["game_config"]["ports"],
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
sess.ref_map_nodes: Dict[str, Node] = {}
|
||||
sess.ref_map_services: Dict[str, Service] = {}
|
||||
sess.ref_map_links: Dict[str, Link] = {}
|
||||
|
||||
nodes_cfg = cfg["simulation"]["network"]["nodes"]
|
||||
links_cfg = cfg["simulation"]["network"]["links"]
|
||||
for node_cfg in nodes_cfg:
|
||||
node_ref = node_cfg["ref"]
|
||||
n_type = node_cfg["type"]
|
||||
if n_type == "computer":
|
||||
new_node = Computer(
|
||||
hostname=node_cfg["hostname"],
|
||||
ip_address=node_cfg["ip_address"],
|
||||
subnet_mask=node_cfg["subnet_mask"],
|
||||
default_gateway=node_cfg["default_gateway"],
|
||||
dns_server=node_cfg["dns_server"],
|
||||
)
|
||||
elif n_type == "server":
|
||||
new_node = Server(
|
||||
hostname=node_cfg["hostname"],
|
||||
ip_address=node_cfg["ip_address"],
|
||||
subnet_mask=node_cfg["subnet_mask"],
|
||||
default_gateway=node_cfg["default_gateway"],
|
||||
dns_server=node_cfg.get("dns_server"),
|
||||
)
|
||||
elif n_type == "switch":
|
||||
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
|
||||
elif n_type == "router":
|
||||
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
|
||||
if "ports" in node_cfg:
|
||||
for port_num, port_cfg in node_cfg["ports"].items():
|
||||
new_node.configure_port(
|
||||
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
|
||||
)
|
||||
if "acl" in node_cfg:
|
||||
for r_num, r_cfg in node_cfg["acl"].items():
|
||||
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
|
||||
# this: 'r_cfg.get('src_port')'
|
||||
# Port/IPProtocol. TODO Refactor
|
||||
new_node.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("ip_address"),
|
||||
dst_ip_address=r_cfg.get("ip_address"),
|
||||
position=r_num,
|
||||
)
|
||||
else:
|
||||
print("invalid node type")
|
||||
if "services" in node_cfg:
|
||||
for service_cfg in node_cfg["services"]:
|
||||
service_ref = service_cfg["ref"]
|
||||
service_type = service_cfg["type"]
|
||||
service_types_mapping = {
|
||||
"DNSClient": DNSClient, # key is equal to the 'name' attr of the service class itself.
|
||||
"DNSServer": DNSServer,
|
||||
"DatabaseClient": DatabaseClient,
|
||||
"DatabaseService": DatabaseService,
|
||||
"WebServer": WebServer,
|
||||
"DataManipulationBot": DataManipulationBot,
|
||||
}
|
||||
if service_type in service_types_mapping:
|
||||
print(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_types_mapping[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
sess.ref_map_services[service_ref] = new_service
|
||||
else:
|
||||
print(f"service type not found {service_type}")
|
||||
# service-dependent options
|
||||
if service_type == "DatabaseClient":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
if "db_server_ip" in opt:
|
||||
new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
|
||||
if service_type == "DNSServer":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
if "domain_mapping" in opt:
|
||||
for domain, ip in opt["domain_mapping"].items():
|
||||
new_service.dns_register(domain, ip)
|
||||
if "applications" in node_cfg:
|
||||
for application_cfg in node_cfg["applications"]:
|
||||
application_ref = application_cfg["ref"]
|
||||
application_type = application_cfg["type"]
|
||||
application_types_mapping = {
|
||||
"WebBrowser": WebBrowser,
|
||||
}
|
||||
if application_type in application_types_mapping:
|
||||
new_node.software_manager.install(application_types_mapping[application_type])
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
sess.ref_map_applications[application_ref] = new_application
|
||||
else:
|
||||
print(f"application type not found {application_type}")
|
||||
if "nics" in node_cfg:
|
||||
for nic_num, nic_cfg in node_cfg["nics"].items():
|
||||
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
|
||||
|
||||
net.add_node(new_node)
|
||||
new_node.power_on()
|
||||
sess.ref_map_nodes[
|
||||
node_ref
|
||||
] = (
|
||||
new_node.uuid
|
||||
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
if isinstance(node_a, Switch):
|
||||
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
|
||||
else:
|
||||
endpoint_a = node_a.ethernet_port[link_cfg["endpoint_a_port"]]
|
||||
if isinstance(node_b, Switch):
|
||||
endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
|
||||
else:
|
||||
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
|
||||
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
|
||||
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
|
||||
# 3. create agents
|
||||
game_cfg = cfg["game_config"]
|
||||
agents_cfg = game_cfg["agents"]
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_ref = agent_cfg["ref"] # noqa: F841
|
||||
agent_type = agent_cfg["type"]
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
# if a list of nodes is defined, convert them from node references to node UUIDs
|
||||
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
|
||||
if "node_ref" in action_node_option:
|
||||
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
|
||||
action_space_cfg["options"]["node_uuids"].append(node_uuid)
|
||||
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
|
||||
# we will pass node_uuids as a part of the action space config.
|
||||
# However, it's not possible to specify the node uuids directly in the config, as they are generated
|
||||
# dynamically, so we have to translate node references to uuids before passing this config on.
|
||||
|
||||
if "action_list" in action_space_cfg:
|
||||
for action_config in action_space_cfg["action_list"]:
|
||||
if "options" in action_config:
|
||||
if "target_router_ref" in action_config["options"]:
|
||||
_target = action_config["options"]["target_router_ref"]
|
||||
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
|
||||
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agent = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
return sess
|
||||
@@ -47,21 +47,21 @@ def convert_size(size_bytes: int) -> str:
|
||||
|
||||
|
||||
class FileSystemItemHealthStatus(Enum):
|
||||
"""Status of the FileSystemItem."""
|
||||
"""Health status for folders and files."""
|
||||
|
||||
GOOD = 0
|
||||
GOOD = 1
|
||||
"""File/Folder is OK."""
|
||||
|
||||
COMPROMISED = 1
|
||||
COMPROMISED = 2
|
||||
"""File/Folder is quarantined."""
|
||||
|
||||
CORRUPT = 2
|
||||
CORRUPT = 3
|
||||
"""File/Folder is corrupted."""
|
||||
|
||||
RESTORING = 3
|
||||
RESTORING = 4
|
||||
"""File/Folder is in the process of being restored."""
|
||||
|
||||
REPAIRING = 3
|
||||
REPAIRING = 5
|
||||
"""File/Folder is in the process of being repaired."""
|
||||
|
||||
|
||||
@@ -92,8 +92,8 @@ class FileSystemItemABC(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["name"] = self.name
|
||||
state["status"] = self.health_status.name
|
||||
state["visible_status"] = self.visible_health_status.name
|
||||
state["status"] = self.health_status.value
|
||||
state["visible_status"] = self.visible_health_status.value
|
||||
state["previous_hash"] = self.previous_hash
|
||||
return state
|
||||
|
||||
|
||||
@@ -160,8 +160,8 @@ class Network(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"nodes": {i for i, node in self._node_id_map.items()},
|
||||
"links": {i: link.describe_state() for i, link in self._link_id_map.items()},
|
||||
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
|
||||
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
|
||||
}
|
||||
)
|
||||
return state
|
||||
@@ -218,7 +218,9 @@ class Network(SimComponent):
|
||||
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
|
||||
self._node_request_manager.remove_request(name=node.uuid)
|
||||
|
||||
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
|
||||
def connect(
|
||||
self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs
|
||||
) -> Optional[Link]:
|
||||
"""
|
||||
Connect two endpoints on the network by creating a link between their NICs/SwitchPorts.
|
||||
|
||||
@@ -245,6 +247,7 @@ class Network(SimComponent):
|
||||
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
|
||||
link.parent = self
|
||||
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
|
||||
return link
|
||||
|
||||
def remove_link(self, link: Link) -> None:
|
||||
"""Disconnect a link from the network.
|
||||
|
||||
@@ -859,14 +859,14 @@ class ICMP:
|
||||
class NodeOperatingState(Enum):
|
||||
"""Enumeration of Node Operating States."""
|
||||
|
||||
OFF = 0
|
||||
"The node is powered off."
|
||||
ON = 1
|
||||
"The node is powered on."
|
||||
SHUTTING_DOWN = 2
|
||||
"The node is in the process of shutting down."
|
||||
OFF = 2
|
||||
"The node is powered off."
|
||||
BOOTING = 3
|
||||
"The node is in the process of booting up."
|
||||
SHUTTING_DOWN = 4
|
||||
"The node is in the process of shutting down."
|
||||
|
||||
|
||||
class Node(SimComponent):
|
||||
@@ -950,6 +950,7 @@ class Node(SimComponent):
|
||||
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
|
||||
if not kwargs.get("software_manager"):
|
||||
kwargs["software_manager"] = SoftwareManager(
|
||||
parent_node=self,
|
||||
sys_log=kwargs.get("sys_log"),
|
||||
session_manager=kwargs.get("session_manager"),
|
||||
file_system=kwargs.get("file_system"),
|
||||
@@ -1252,7 +1253,8 @@ class Node(SimComponent):
|
||||
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
"""Uninstall and completely remove service from this node.
|
||||
"""
|
||||
Uninstall and completely remove service from this node.
|
||||
|
||||
:param service: Service object that is currently associated with this node.
|
||||
:type service: Service
|
||||
@@ -1267,6 +1269,38 @@ class Node(SimComponent):
|
||||
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
|
||||
self._service_request_manager.remove_request(service.uuid)
|
||||
|
||||
def install_application(self, application: Application) -> None:
|
||||
"""
|
||||
Install an application on this node.
|
||||
|
||||
:param application: Application instance that has not been installed on any node yet.
|
||||
:type application: Application
|
||||
"""
|
||||
if application in self:
|
||||
_LOGGER.warning(f"Can't add application {application.uuid} to node {self.uuid}. It's already installed.")
|
||||
return
|
||||
self.applications[application.uuid] = application
|
||||
application.parent = self
|
||||
self.sys_log.info(f"Installed application {application.name}")
|
||||
_LOGGER.info(f"Added application {application.uuid} to node {self.uuid}")
|
||||
self._application_request_manager.add_request(application.uuid, RequestType(func=application._request_manager))
|
||||
|
||||
def uninstall_application(self, application: Application) -> None:
|
||||
"""
|
||||
Uninstall and completely remove application from this node.
|
||||
|
||||
:param application: Application object that is currently associated with this node.
|
||||
:type application: Application
|
||||
"""
|
||||
if application not in self:
|
||||
_LOGGER.warning(f"Can't remove application {application.uuid} from node {self.uuid}. It's not installed.")
|
||||
return
|
||||
self.applications.pop(application.uuid)
|
||||
application.parent = None
|
||||
self.sys_log.info(f"Uninstalled application {application.name}")
|
||||
_LOGGER.info(f"Removed application {application.uuid} from node {self.uuid}")
|
||||
self._application_request_manager.remove_request(application.uuid)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
if isinstance(item, Service):
|
||||
return item.uuid in self.services
|
||||
|
||||
@@ -58,7 +58,14 @@ class ACLRule(SimComponent):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
class AccessControlList(SimComponent):
|
||||
@@ -104,11 +111,11 @@ class AccessControlList(SimComponent):
|
||||
RequestType(
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
IPProtocol[request[1]],
|
||||
IPv4Address[request[2]],
|
||||
Port[request[3]],
|
||||
IPv4Address[request[4]],
|
||||
Port[request[5]],
|
||||
None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
IPv4Address(request[2]),
|
||||
None if request[3] == "ALL" else Port[request[3]],
|
||||
IPv4Address(request[4]),
|
||||
None if request[5] == "ALL" else Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
),
|
||||
@@ -123,7 +130,12 @@ class AccessControlList(SimComponent):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["implicit_action"] = self.implicit_action.value
|
||||
state["implicit_rule"] = self.implicit_rule.describe_state()
|
||||
state["max_acl_rules"] = self.max_acl_rules
|
||||
state["acl"] = {i: r.describe_state() if isinstance(r, ACLRule) else None for i, r in enumerate(self._acl)}
|
||||
return state
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Optional[ACLRule]]:
|
||||
@@ -648,7 +660,10 @@ class Router(Node):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
|
||||
@@ -55,12 +55,11 @@ class Switch(Node):
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
"""
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
"num_ports": self.num_ports, # redundant?
|
||||
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
|
||||
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
|
||||
}
|
||||
state = super().describe_state()
|
||||
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
|
||||
state["num_ports"] = self.num_ports # redundant?
|
||||
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
|
||||
return state
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
"""
|
||||
|
||||
@@ -27,6 +27,7 @@ class Simulation(SimComponent):
|
||||
rm.add_request("network", RequestType(func=self.network._request_manager))
|
||||
# pass through domain requests to the domain object
|
||||
rm.add_request("domain", RequestType(func=self.domain._request_manager))
|
||||
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -45,7 +45,7 @@ class Application(IOSoftware):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"opearting_state": self.operating_state.name,
|
||||
"opearting_state": self.operating_state.value,
|
||||
"execution_control_status": self.execution_control_status,
|
||||
"num_executions": self.num_executions,
|
||||
"groups": list(self.groups),
|
||||
|
||||
@@ -38,7 +38,8 @@ class WebBrowser(Application):
|
||||
|
||||
:return: A dictionary capturing the current state of the WebBrowser and its child objects.
|
||||
"""
|
||||
return super().describe_state()
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
|
||||
@@ -14,6 +14,7 @@ from primaite.simulator.system.software import IOSoftware
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
from typing import Type, TypeVar
|
||||
|
||||
@@ -25,6 +26,7 @@ class SoftwareManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_node: "Node",
|
||||
session_manager: "SessionManager",
|
||||
sys_log: SysLog,
|
||||
file_system: FileSystem,
|
||||
@@ -35,6 +37,7 @@ class SoftwareManager:
|
||||
|
||||
:param session_manager: The session manager handling network communications.
|
||||
"""
|
||||
self.node = parent_node
|
||||
self.session_manager = session_manager
|
||||
self.software: Dict[str, Union[Service, Application]] = {}
|
||||
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
|
||||
@@ -62,6 +65,8 @@ class SoftwareManager:
|
||||
|
||||
:param software_class: The software class.
|
||||
"""
|
||||
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
|
||||
# separation of concerns.
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
@@ -77,6 +82,12 @@ class SoftwareManager:
|
||||
if isinstance(software, Application):
|
||||
software.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
# add the software to the node's registry after it has been fully initialized
|
||||
if isinstance(software, Service):
|
||||
self.node.install_service(software)
|
||||
elif isinstance(software, Application):
|
||||
self.node.install_application(software)
|
||||
|
||||
def uninstall(self, software_name: str):
|
||||
"""
|
||||
Uninstall an Application or Service.
|
||||
@@ -85,6 +96,10 @@ class SoftwareManager:
|
||||
"""
|
||||
if software_name in self.software:
|
||||
software = self.software.pop(software_name) # noqa
|
||||
if isinstance(software, Application):
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
return
|
||||
|
||||
@@ -15,14 +15,14 @@ class ServiceOperatingState(Enum):
|
||||
"The service is currently running."
|
||||
STOPPED = 2
|
||||
"The service is not running."
|
||||
INSTALLING = 3
|
||||
"The service is being installed or updated."
|
||||
RESTARTING = 4
|
||||
"The service is in the process of restarting."
|
||||
PAUSED = 5
|
||||
PAUSED = 3
|
||||
"The service is temporarily paused."
|
||||
DISABLED = 6
|
||||
DISABLED = 4
|
||||
"The service is disabled and cannot be started."
|
||||
INSTALLING = 5
|
||||
"The service is being installed or updated."
|
||||
RESTARTING = 6
|
||||
"The service is in the process of restarting."
|
||||
|
||||
|
||||
class Service(IOSoftware):
|
||||
@@ -68,7 +68,7 @@ class Service(IOSoftware):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["operating_state"] = self.operating_state.name
|
||||
state["operating_state"] = self.operating_state.value
|
||||
state["health_state_actual"] = self.health_state_actual
|
||||
state["health_state_visible"] = self.health_state_visible
|
||||
return state
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
@@ -17,6 +17,23 @@ from primaite.simulator.system.services.service import Service
|
||||
class WebServer(Service):
|
||||
"""Class used to represent a Web Server Service in simulation."""
|
||||
|
||||
last_response_status_code: Optional[HttpStatusCode] = None
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = (
|
||||
self.last_response_status_code.value if self.last_response_status_code else None
|
||||
)
|
||||
return state
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "WebServer"
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
@@ -66,6 +83,7 @@ class WebServer(Service):
|
||||
self.send(payload=response, session_id=session_id)
|
||||
|
||||
# return true if response is OK
|
||||
self.last_response_status_code = response.status_code
|
||||
return response.status_code == HttpStatusCode.OK
|
||||
|
||||
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:
|
||||
|
||||
@@ -121,9 +121,9 @@ class Software(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"health_state": self.health_state_actual.name,
|
||||
"health_state_red_view": self.health_state_visible.name,
|
||||
"criticality": self.criticality.name,
|
||||
"health_state": self.health_state_actual.value,
|
||||
"health_state_red_view": self.health_state_visible.value,
|
||||
"criticality": self.criticality.value,
|
||||
"patching_count": self.patching_count,
|
||||
"scanning_count": self.scanning_count,
|
||||
"revealed_to_red": self.revealed_to_red,
|
||||
|
||||
@@ -7,7 +7,7 @@ from primaite.common.enums import AgentIdentifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
|
||||
5
src/primaite/utils/start_gate_server.py
Normal file
5
src/primaite/utils/start_gate_server.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
service = GATEService()
|
||||
service.start()
|
||||
20
tests/integration_tests/game_layer/test_observations.py
Normal file
20
tests/integration_tests/game_layer/test_observations.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from gym import spaces
|
||||
|
||||
from primaite.game.agent.observations import FileObservation
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
def test_file_observation():
|
||||
sim = Simulation()
|
||||
pc = Computer(hostname="beep", ip_address="123.123.123.123", subnet_mask="255.255.255.0")
|
||||
sim.network.add_node(pc)
|
||||
f = pc.file_system.create_file(file_name="dog.png")
|
||||
|
||||
state = sim.describe_state()
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.uuid, "file_system", "folders", "root", "files", "dog.png"]
|
||||
)
|
||||
assert dog_file_obs.observe(state) == {"health_status": 1}
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
Reference in New Issue
Block a user