diff --git a/.gitignore b/.gitignore index ce42d9a9..a81c8ee1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,3 @@ -# PrimAITE Package -src/primaite/outputs -src/primaite/outputs/* -src/primaite/logs -src/primaite/logs/* -TestResults - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/MANIFEST.in b/MANIFEST.in index da226e04..04a90e0e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -include src/primaite/config/*.yaml +include src/primaite/config/_package_data/*.yaml diff --git a/docs/index.rst b/docs/index.rst index b92d493e..1da23718 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,15 +9,15 @@ Welcome to PrimAITE's documentation What is PrimAITE? ------------------------ -PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme. It incorporates the functionality required of a Primary-level environment, as specified in the Dstl ARCD Training Environment Matrix document:​ +PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme. It incorporates the functionality required of a Primary-level environment, as specified in the Dstl ARCD Training Environment Matrix document: * The ability to model a relevant platform / system context; * The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, traffic loading, operating systems, file system, services and processes; * Operates at machine-speed to enable fast training cycles. -PrimAITE aims to evolve into an ARCD environment that could be used as the follow-on from Reception level approaches (e.g. YAWNING TITAN), and help bridge the Sim-to-Real gap into Secondary level environments (e.g. IMAGINARY YAK)​. +PrimAITE aims to evolve into an ARCD environment that could be used as the follow-on from Reception level approaches (e.g. 
YAWNING TITAN), and help bridge the Sim-to-Real gap into Secondary level environments (e.g. IMAGINARY YAK). -This is similar to the approach taken by FVEY international partners (e.g. AUS CyBORG, US NSA FARLAND and CAN CyGil). These environments are referenced by the Dstl ARCD Agent Training Environments Knowledge Transfer document (TR141342)​. +This is similar to the approach taken by FVEY international partners (e.g. AUS CyBORG, US NSA FARLAND and CAN CyGil). These environments are referenced by the Dstl ARCD Agent Training Environments Knowledge Transfer document (TR141342). What is PrimAITE built with -------------------------------------- diff --git a/docs/source/about.rst b/docs/source/about.rst index 8cc08b13..180d0549 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -8,51 +8,51 @@ Features PrimAITE provides the following features: -* A flexible network / system laydown based on the Python networkx framework​ -* Nodes and links (edges) host Python classes in order to present attributes and methods (and hence, a more representative model of a platform / system)​ -* A ‘green agent’ Information Exchange Requirement (IER) function allows the representation of traffic (protocols and loading) on any / all links. Application of IERs is based on the status of node operating systems and services​ -* A ‘green agent’ node Pattern-of-Life (PoL) function allows the representation of core behaviours on nodes (e.g. Hardware state, Software State, Service state, File System state)​ -* An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP, destination IP, protocol and port). 
Application of IERs adheres to any ACL restrictions​ -* Presents an OpenAI Gym interface to the environment, allowing integration with any OpenAI Gym compliant defensive agents ​ +* A flexible network / system laydown based on the Python networkx framework +* Nodes and links (edges) host Python classes in order to present attributes and methods (and hence, a more representative model of a platform / system) +* A ‘green agent’ Information Exchange Requirement (IER) function allows the representation of traffic (protocols and loading) on any / all links. Application of IERs is based on the status of node operating systems and services +* A ‘green agent’ node Pattern-of-Life (PoL) function allows the representation of core behaviours on nodes (e.g. Hardware state, Software State, Service state, File System state) +* An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP, destination IP, protocol and port). Application of IERs adheres to any ACL restrictions +* Presents an OpenAI Gym interface to the environment, allowing integration with any OpenAI Gym compliant defensive agents * Red agent activity based on ‘red’ IERs and ‘red’ PoL -* Defined reward function for use with RL agents (based on nodes status, and green / red IER success)​ -* Fully configurable (network / system laydown, IERs, node PoL, ACL, episode step period, episode max steps) and repeatable to suit the training requirements of agents. 
Therefore, not bound to a representation of any particular platform, system or technology​ -* Full capture of discrete metrics relating to agent training (full system state, agent actions taken, average reward)​ -* Networkx provides laydown visualisation capability ​ +* Defined reward function for use with RL agents (based on nodes status, and green / red IER success) +* Fully configurable (network / system laydown, IERs, node PoL, ACL, episode step period, episode max steps) and repeatable to suit the training requirements of agents. Therefore, not bound to a representation of any particular platform, system or technology +* Full capture of discrete metrics relating to agent training (full system state, agent actions taken, average reward) +* Networkx provides laydown visualisation capability Architecture - Nodes and Links ****************************** **Nodes** -An inheritance model has been adopted in order to model nodes. All nodes have the following base attributes (Class: Node):​ +An inheritance model has been adopted in order to model nodes. All nodes have the following base attributes (Class: Node): -* ID​ +* ID * Name -* Type (e.g. computer, switch, RTU - enumeration)​ -* Priority (P1, P2, P3, P4 or P5 - enumeration)​ -* Hardware State (ON, OFF, RESETTING - enumeration)​ +* Type (e.g. 
computer, switch, RTU - enumeration) +* Priority (P1, P2, P3, P4 or P5 - enumeration) +* Hardware State (ON, OFF, RESETTING - enumeration) -Active Nodes also have the following attributes (Class: Active Node):​ +Active Nodes also have the following attributes (Class: Active Node): -* IP Address​ -* Software State (GOOD, PATCHING, COMPROMISED - enumeration)​ +* IP Address +* Software State (GOOD, PATCHING, COMPROMISED - enumeration) * File System State (GOOD, CORRUPT, DESTROYED, REPAIRING, RESTORING - enumeration) -Service Nodes also have the following attributes (Class: Service Node)​: +Service Nodes also have the following attributes (Class: Service Node): -* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type)​ +* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type) * Service state (GOOD, PATCHING, COMPROMISED, OVERWHELMED - enumeration) Passive Nodes are currently not used (but may be employed for non IP-based components such as machinery actuators in future releases). **Links** -Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality​. Links include the following attributes:​ +Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality. 
Links include the following attributes: -* ID​ +* ID * Name -* Bandwidth (bits/s)​ +* Bandwidth (bits/s) * Source node ID * Destination node ID * Protocol list (containing the loading of protocols currently running on the link) @@ -62,32 +62,32 @@ When the simulation runs, IERs are applied to the links in order to model traffi Information Exchange Requirements (IERs) **************************************** -PrimAITE adopts the concept of Information Exchange Requirements (IERs) to model both green agent (background) and red agent (adversary) behaviour. IERs are used to initiate modelling of traffic loading on the network, and have the following attributes:​ +PrimAITE adopts the concept of Information Exchange Requirements (IERs) to model both green agent (background) and red agent (adversary) behaviour. IERs are used to initiate modelling of traffic loading on the network, and have the following attributes: -* ID​ -* Start step (i.e. which step in the training episode should the IER start)​ -* End step​ (i.e. which step in the training episode should the IER end) +* ID +* Start step (i.e. which step in the training episode should the IER start) +* End step (i.e. which step in the training episode should the IER end) * Source node ID -* Destination node ID​ -* Load (bits/s)​ -* Protocol​ -* Port​ +* Destination node ID +* Load (bits/s) +* Protocol +* Port * Running status (i.e. on / off) -The application of green agent IERs between a source and destination follows a number of rules. Specifically:​ +The application of green agent IERs between a source and destination follows a number of rules. Specifically: -1. Does the current simulation time step fall between IER start and end step​ -2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)​ -3. 
Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)​ -4. Are there any Access Control List rules in place that prevent the application of this IER​ -5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)​ +1. Does the current simulation time step fall between IER start and end step +2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING) +3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING) +4. Are there any Access Control List rules in place that prevent the application of this IER +5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level) -For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically:​ +For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically: -1. Does the current simulation time step fall between IER start and end step​ -2. Is the source node operational, and is the service (protocol / port) associated with the IER (a) present on that node and (b) already in a compromised state​ -3. Is the destination node operational, and is the service (protocol / port) associated with the IER present on that node​ -4. Are there any Access Control List rules in place that prevent the application of this IER​ +1. Does the current simulation time step fall between IER start and end step +2. 
Is the source node operational, and is the service (protocol / port) associated with the IER (a) present on that node and (b) already in a compromised state +3. Is the destination node operational, and is the service (protocol / port) associated with the IER present on that node +4. Are there any Access Control List rules in place that prevent the application of this IER 5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level) Assuming the rules pass, the IER is applied to all relevant links (based on use of OSPF) between source and destination. @@ -149,7 +149,7 @@ Red agent pattern-of-life has an additional feature not found in the green patte Access Control List modelling ***************************** -An Access Control List (ACL) is modelled to provide the means to manage traffic flows in the system. This will allow defensive agents the means to turn on / off rules, or potentially create new rules, to counter an attack​. +An Access Control List (ACL) is modelled to provide the means to manage traffic flows in the system. This will allow defensive agents the means to turn on / off rules, or potentially create new rules, to counter an attack. The ACL follows a standard network firewall format. For example: @@ -183,9 +183,9 @@ All ACL rules are considered when applying an IER. 
Logic follows the order of ru Observation Spaces ****************** -The OpenAI Gym observation space provides the status of all nodes and links across the whole system:​ +The OpenAI Gym observation space provides the status of all nodes and links across the whole system: -* Nodes (in terms of hardware state, Software State, file system state and services state) ​ +* Nodes (in terms of hardware state, Software State, file system state and services state) * Links (in terms of current loading for each service/protocol) The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config. diff --git a/docs/source/config.rst b/docs/source/config.rst index 88399973..dec8af85 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -5,17 +5,22 @@ The Config Files Explained PrimAITE uses two configuration files for its operation: -* config_main.yaml - used to define the top-level settings of the PrimAITE environment, and the session that is to be run. -* config_[name].yaml - used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERSs), Access Control Rules, Action Space type, and the number of steps in each episode. +* **The Training Config** -config_main.yaml: -***************** + Used to define the top-level settings of the PrimAITE environment, the reward values, and the session that is to be run. -The config_main.yaml file consists of the following attributes: +* **The Lay Down Config** + + Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERs) and Access Control Rules. 
+ +The Training Config +******************* + +The training config file consists of the following attributes: **Generic Config Values** -* **agentIdentifier** [enum] +* **agent_identifier** [enum] This identifies the agent to use for the session. Select from one of the following: @@ -23,61 +28,68 @@ The config_main.yaml file consists of the following attributes: * STABLE_BASELINES3_PPO - Use a SB3 PPO agent * STABLE_BASELINES3_A2C - use a SB3 A2C agent -* **numEpisodes** [int] +* **action_type** [enum] - This defines the number of episodes that the agent will train or be evaluated over. Each episode consists of a number of steps (with step number defined in the config_[name].yaml file) + Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session -* **timeDelay** [int] + +* **num_episodes** [int] + + This defines the number of episodes that the agent will train or be evaluated over. + +* **num_steps** [int] + + Determines the number of steps to run in each episode of the session + + +* **time_delay** [int] The time delay (in milliseconds) to take between each step when running a GENERIC agent session -* **configFilename** [filename] - The name of the config_[name].yaml file to use for this session - -* **sessionType** [text] +* **session_type** [text] Type of session to be run (TRAINING or EVALUATION) -* **loadAgent** [bool] +* **load_agent** [bool] Determine whether to load an agent from file -* **agentLoadFile** [text] +* **agent_load_file** [text] File path and file name of agent if you're loading one in -* **observationSpaceHighValue** [int] +* **observation_space_high_value** [int] The high value to use for values in the observation space. 
This is set to 1000000000 by default, and should not need changing in most cases **Reward-Based Config Values** -* **Generic [allOk]** [int] +* **Generic [all_ok]** [int] The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) -* **Node Hardware State [offShouldBeOn]** [int] +* **Node Hardware State [off_should_be_on]** [int] The score to give when the node should be on, but is off -* **Node Hardware State [offShouldBeResetting]** [int] +* **Node Hardware State [off_should_be_resetting]** [int] The score to give when the node should be resetting, but is off -* **Node Hardware State [onShouldBeOff]** [int] +* **Node Hardware State [on_should_be_off]** [int] The score to give when the node should be off, but is on -* **Node Hardware State [onShouldBeResetting]** [int] +* **Node Hardware State [on_should_be_resetting]** [int] The score to give when the node should be resetting, but is on -* **Node Hardware State [resettingShouldBeOn]** [int] +* **Node Hardware State [resetting_should_be_on]** [int] The score to give when the node should be on, but is resetting -* **Node Hardware State [resettingShouldBeOff]** [int] +* **Node Hardware State [resetting_should_be_off]** [int] The score to give when the node should be off, but is resetting @@ -85,27 +97,27 @@ The config_main.yaml file consists of the following attributes: The score to give when the node is resetting -* **Node Operating System or Service State [goodShouldBePatching]** [int] +* **Node Operating System or Service State [good_should_be_patching]** [int] The score to give when the state should be patching, but is good -* **Node Operating System or Service State [goodShouldBeCompromised]** [int] +* **Node Operating System or Service State [good_should_be_compromised]** [int] The score to give when the state should be compromised, but is good -* **Node Operating System or Service 
State [goodShouldBeOverwhelmed]** [int] +* **Node Operating System or Service State [good_should_be_overwhelmed]** [int] The score to give when the state should be overwhelmed, but is good -* **Node Operating System or Service State [patchingShouldBeGood]** [int] +* **Node Operating System or Service State [patching_should_be_good]** [int] The score to give when the state should be good, but is patching -* **Node Operating System or Service State [patchingShouldBeCompromised]** [int] +* **Node Operating System or Service State [patching_should_be_compromised]** [int] The score to give when the state should be compromised, but is patching -* **Node Operating System or Service State [patchingShouldBeOverwhelmed]** [int] +* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] The score to give when the state should be overwhelmed, but is patching @@ -113,15 +125,15 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is patching -* **Node Operating System or Service State [compromisedShouldBeGood]** [int] +* **Node Operating System or Service State [compromised_should_be_good]** [int] The score to give when the state should be good, but is compromised -* **Node Operating System or Service State [compromisedShouldBePatching]** [int] +* **Node Operating System or Service State [compromised_should_be_patching]** [int] The score to give when the state should be patching, but is compromised -* **Node Operating System or Service State [compromisedShouldBeOverwhelmed]** [int] +* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] The score to give when the state should be overwhelmed, but is compromised @@ -129,15 +141,15 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is compromised -* **Node Operating System or Service State [overwhelmedShouldBeGood]** [int] +* **Node Operating System or Service State 
[overwhelmed_should_be_good]** [int] The score to give when the state should be good, but is overwhelmed -* **Node Operating System or Service State [overwhelmedShouldBePatching]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] The score to give when the state should be patching, but is overwhelmed -* **Node Operating System or Service State [overwhelmedShouldBeCompromised]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] The score to give when the state should be compromised, but is overwhelmed @@ -145,35 +157,35 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is overwhelmed -* **Node File System State [goodShouldBeRepairing]** [int] +* **Node File System State [good_should_be_repairing]** [int] The score to give when the state should be repairing, but is good -* **Node File System State [goodShouldBeRestoring]** [int] +* **Node File System State [good_should_be_restoring]** [int] The score to give when the state should be restoring, but is good -* **Node File System State [goodShouldBeCorrupt]** [int] +* **Node File System State [good_should_be_corrupt]** [int] The score to give when the state should be corrupt, but is good -* **Node File System State [goodShouldBeDestroyed]** [int] +* **Node File System State [good_should_be_destroyed]** [int] The score to give when the state should be destroyed, but is good -* **Node File System State [repairingShouldBeGood]** [int] +* **Node File System State [repairing_should_be_good]** [int] The score to give when the state should be good, but is repairing -* **Node File System State [repairingShouldBeRestoring]** [int] +* **Node File System State [repairing_should_be_restoring]** [int] The score to give when the state should be restoring, but is repairing -* **Node File System State [repairingShouldBeCorrupt]** [int] +* **Node File System State [repairing_should_be_corrupt]** [int] 
The score to give when the state should be corrupt, but is repairing -* **Node File System State [repairingShouldBeDestroyed]** [int] +* **Node File System State [repairing_should_be_destroyed]** [int] The score to give when the state should be destroyed, but is repairing @@ -181,19 +193,19 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is repairing -* **Node File System State [restoringShouldBeGood]** [int] +* **Node File System State [restoring_should_be_good]** [int] The score to give when the state should be good, but is restoring -* **Node File System State [restoringShouldBeRepairing]** [int] +* **Node File System State [restoring_should_be_repairing]** [int] The score to give when the state should be repairing, but is restoring -* **Node File System State [restoringShouldBeCorrupt]** [int] +* **Node File System State [restoring_should_be_corrupt]** [int] The score to give when the state should be corrupt, but is restoring -* **Node File System State [restoringShouldBeDestroyed]** [int] +* **Node File System State [restoring_should_be_destroyed]** [int] The score to give when the state should be destroyed, but is restoring @@ -201,19 +213,19 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is restoring -* **Node File System State [corruptShouldBeGood]** [int] +* **Node File System State [corrupt_should_be_good]** [int] The score to give when the state should be good, but is corrupt -* **Node File System State [corruptShouldBeRepairing]** [int] +* **Node File System State [corrupt_should_be_repairing]** [int] The score to give when the state should be repairing, but is corrupt -* **Node File System State [corruptShouldBeRestoring]** [int] +* **Node File System State [corrupt_should_be_restoring]** [int] The score to give when the state should be restoring, but is corrupt -* **Node File System State [corruptShouldBeDestroyed]** [int] +* **Node File System 
State [corrupt_should_be_destroyed]** [int] The score to give when the state should be destroyed, but is corrupt @@ -221,19 +233,19 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is corrupt -* **Node File System State [destroyedShouldBeGood]** [int] +* **Node File System State [destroyed_should_be_good]** [int] The score to give when the state should be good, but is destroyed -* **Node File System State [destroyedShouldBeRepairing]** [int] +* **Node File System State [destroyed_should_be_repairing]** [int] The score to give when the state should be repairing, but is destroyed -* **Node File System State [destroyedShouldBeRestoring]** [int] +* **Node File System State [destroyed_should_be_restoring]** [int] The score to give when the state should be restoring, but is destroyed -* **Node File System State [destroyedShouldBeCorrupt]** [int] +* **Node File System State [destroyed_should_be_corrupt]** [int] The score to give when the state should be corrupt, but is destroyed @@ -245,52 +257,44 @@ The config_main.yaml file consists of the following attributes: The score to give when the state is scanning -* **IER Status [redIerRunning]** [int] +* **IER Status [red_ier_running]** [int] The score to give when a red agent IER is permitted to run -* **IER Status [greenIerBlocked]** [int] +* **IER Status [green_ier_blocked]** [int] The score to give when a green agent IER is prevented from running **Patching / Reset Durations** -* **osPatchingDuration** [int] +* **os_patching_duration** [int] The number of steps to take when patching an Operating System -* **nodeResetDuration** [int] +* **node_reset_duration** [int] The number of steps to take when resetting a node's hardware state -* **servicePatchingDuration** [int] +* **service_patching_duration** [int] The number of steps to take when patching a service -* **fileSystemRepairingLimit** [int]: +* **file_system_repairing_limit** [int]: The number of steps to take when 
repairing the file system -* **fileSystemRestoringLimit** [int] +* **file_system_restoring_limit** [int] The number of steps to take when restoring the file system -* **fileSystemScanningLimit** [int] +* **file_system_scanning_limit** [int] The number of steps to take when scanning the file system -config_[name].yaml: +The Lay Down Config ******************* -The config_[name].yaml file consists of the following attributes: - -* **itemType: ACTIONS** [enum] - - Determines whether a NODE or ACL action space format is adopted for the session - -* **itemType: STEPS** [int] - - Determines the number of steps to run in each episode of the session +The lay down config file consists of the following attributes: * **itemType: PORTS** [int] diff --git a/docs/source/session.rst b/docs/source/session.rst index 3e1fb940..5fefb371 100644 --- a/docs/source/session.rst +++ b/docs/source/session.rst @@ -29,10 +29,10 @@ the run_generic function should be selected, and should be modified (typically) .. code:: python - agent = MyAgent(environment, max_steps)​ - for episode in range(0, num_episodes):​ - agent.learn() ​ - env.close()​ + agent = MyAgent(environment, max_steps) + for episode in range(0, num_episodes): + agent.learn() + env.close() save_agent(agent) Where: @@ -51,29 +51,29 @@ environment is reset between episodes. Note that the example below should not be .. code:: python - def learn(self) :​ + def learn(self) : - # pre-reqs​​ + # pre-reqs - # reset the environment​ - self.environment.reset()​ - done = False​ + # reset the environment + self.environment.reset() + done = False - for step in range(max_steps):​ - # calculate the action​ + for step in range(max_steps): + # calculate the action action = ... - ​# execute the environment step​ - new_state, reward, done, info = self.environment.step(action)​ + # execute the environment step + new_state, reward, done, info = self.environment.step(action) - # algorithm updates​ + # algorithm updates ... 
- # update to our new state​ - state = new_state​ + # update to our new state + state = new_state - # if done, finish episode​ - if done == True:​ + # if done, finish episode + if done == True: break **Running the session** diff --git a/pyproject.toml b/pyproject.toml index 48c2e2f2..812e35c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,3 +59,6 @@ dev = [ "wheel==0.38.4", "build==0.10.0" ] + +[project.scripts] +primaite = "primaite.cli:app" diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py new file mode 100644 index 00000000..3c38a53b --- /dev/null +++ b/src/primaite/__init__.py @@ -0,0 +1,98 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import logging +import logging.config +import sys +from logging import Logger, StreamHandler +from logging.handlers import RotatingFileHandler +from pathlib import Path +from typing import Final + +from platformdirs import PlatformDirs + +_PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") +"""An instance of `PlatformDirs` set with appname='primaite'.""" + +_USER_DIRS: Final[Path] = Path.home() / "primaite" +"""The users home space for PrimAITE which is located at: ~/primaite.""" + +NOTEBOOKS_DIR: Final[Path] = _USER_DIRS / "notebooks" +""" +The path to the users notebooks directory as an instance of `Path` or +`PosixPath`, depending on the OS. + +Users notebooks are stored at: ``~/primaite/notebooks``. +""" + +USERS_CONFIG_DIR: Final[Path] = _USER_DIRS / "config" +""" +The path to the users config directory as an instance of `Path` or +`PosixPath`, depending on the OS. + +Users config files are stored at: ``~/primaite/config``. +""" + +SESSIONS_DIR: Final[Path] = _USER_DIRS / "sessions" +""" +The path to the users PrimAITE Sessions directory as an instance of `Path` or +`PosixPath`, depending on the OS. + +Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. 
+""" + + +# region Setup Logging +def _log_dir() -> Path: + if sys.platform == "win32": + dir_path = _PLATFORM_DIRS.user_data_path / "logs" + else: + dir_path = _PLATFORM_DIRS.user_log_path + return dir_path + + +LOG_DIR: Final[Path] = _log_dir() +"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS.""" + +LOG_PATH: Final[Path] = LOG_DIR / "primaite.log" +"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS.""" + +_STREAM_HANDLER: Final[StreamHandler] = StreamHandler() +_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( + filename=LOG_PATH, + maxBytes=10485760, # 10MB + backupCount=9, # Max 100MB of logs + encoding="utf8", +) +_STREAM_HANDLER.setLevel(logging.INFO) +_FILE_HANDLER.setLevel(logging.INFO) + +_LOG_FORMAT_STR: Final[ + str +] = "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s" +_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) +_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) + +_LOGGER = logging.getLogger(__name__) + +_LOGGER.addHandler(_STREAM_HANDLER) +_LOGGER.addHandler(_FILE_HANDLER) + + +def getLogger(name: str) -> Logger: + """ + Get a PrimAITE logger. + + :param name: The logger name. Use ``__name__``. + :return: An instance of :py:class:`logging.Logger` with the PrimAITE + logging config. + """ + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + return logger + + +# endregion + + +with open(Path(__file__).parent.resolve() / "VERSION", "r") as file: + __version__ = file.readline() diff --git a/src/primaite/cli.py b/src/primaite/cli.py new file mode 100644 index 00000000..ebc126e0 --- /dev/null +++ b/src/primaite/cli.py @@ -0,0 +1,125 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
"""Provides a CLI using Typer as an entry point."""
import os
import sys

import typer

app = typer.Typer()


@app.command()
def build_dirs():
    """Build the PrimAITE app directories."""
    from primaite.setup import setup_app_dirs

    setup_app_dirs.run()


@app.command()
def reset_notebooks(overwrite: bool = True):
    """
    Force a reset of the demo notebooks in the users notebooks directory.

    :param overwrite: If True, will overwrite existing demo notebooks.
    """
    from primaite.setup import reset_demo_notebooks

    reset_demo_notebooks.run(overwrite)


@app.command()
def logs(last_n: int = 10):
    """
    Print the PrimAITE log file.

    :param last_n: The number of lines to print. Default value is 10.
    """
    from collections import deque

    from platformdirs import PlatformDirs

    # Asking for zero (or fewer) lines prints nothing. Slicing with
    # ``lines[-0:]`` would have dumped the entire file instead.
    if last_n <= 0:
        return

    platform_dirs = PlatformDirs(appname="primaite")

    if sys.platform == "win32":
        log_dir = platform_dirs.user_data_path / "logs"
    else:
        log_dir = platform_dirs.user_log_path
    log_path = os.path.join(log_dir, "primaite.log")

    if os.path.isfile(log_path):
        with open(log_path) as file:
            # A bounded deque keeps only the last ``last_n`` lines while the
            # file is streamed, so the whole log is never held in memory.
            tail = deque(file, maxlen=last_n)
        for line in tail:
            print(line.rstrip("\n"))


@app.command()
def notebooks():
    """Start Jupyter Lab in the users PrimAITE notebooks directory."""
    from primaite.notebooks import start_jupyter_session

    start_jupyter_session()


@app.command()
def version():
    """Get the installed PrimAITE version number."""
    import primaite

    print(primaite.__version__)


@app.command()
def clean_up():
    """Cleans up left over files from previous version installations."""
    from primaite.setup import old_installation_clean_up

    old_installation_clean_up.run()


@app.command()
def setup():
    """
    Perform the PrimAITE first-time setup.

    WARNING: All user-data will be lost.
    """
    from primaite import getLogger
    from primaite.setup import (
        old_installation_clean_up,
        reset_demo_notebooks,
        reset_example_configs,
        setup_app_dirs,
    )

    _LOGGER = getLogger(__name__)

    _LOGGER.info("Performing the PrimAITE first-time setup...")

    _LOGGER.info("Building the PrimAITE app directories...")
    setup_app_dirs.run()

    _LOGGER.info("Rebuilding the demo notebooks...")
    reset_demo_notebooks.run(overwrite_existing=True)

    # This step resets the example configs, so the message says so.
    _LOGGER.info("Rebuilding the example configs...")
    reset_example_configs.run(overwrite_existing=True)

    _LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
    old_installation_clean_up.run()

    _LOGGER.info("PrimAITE setup complete!")


@app.command()
def session(tc: str, ldc: str):
    """
    Run a PrimAITE session.

    :param tc: The training config filepath.
    :param ldc: The lay down config file path.
    """
    from primaite.main import run

    run(training_config_path=tc, lay_down_config_path=ldc)
-"""The config class.""" - - -class ConfigValuesMain(object): - """Class to hold main config values.""" - - def __init__(self): - """Init.""" - # Generic - self.agent_identifier = "" # the agent in use - self.num_episodes = 0 # number of episodes to train over - self.num_steps = 0 # number of steps in an episode - self.time_delay = 0 # delay between steps (ms) - applies to generic agents only - self.config_filename_use_case = "" # the filename for the Use Case config file - self.session_type = "" # the session type to run (TRAINING or EVALUATION) - - # Environment - self.observation_space_high_value = ( - 0 # The high value for the observation space - ) - - # Reward values - # Generic - self.all_ok = 0 - # Node Hardware State - self.off_should_be_on = 0 - self.off_should_be_resetting = 0 - self.on_should_be_off = 0 - self.on_should_be_resetting = 0 - self.resetting_should_be_on = 0 - self.resetting_should_be_off = 0 - self.resetting = 0 - # Node Software or Service State - self.good_should_be_patching = 0 - self.good_should_be_compromised = 0 - self.good_should_be_overwhelmed = 0 - self.patching_should_be_good = 0 - self.patching_should_be_compromised = 0 - self.patching_should_be_overwhelmed = 0 - self.patching = 0 - self.compromised_should_be_good = 0 - self.compromised_should_be_patching = 0 - self.compromised_should_be_overwhelmed = 0 - self.compromised = 0 - self.overwhelmed_should_be_good = 0 - self.overwhelmed_should_be_patching = 0 - self.overwhelmed_should_be_compromised = 0 - self.overwhelmed = 0 - # Node File System State - self.good_should_be_repairing = 0 - self.good_should_be_restoring = 0 - self.good_should_be_corrupt = 0 - self.good_should_be_destroyed = 0 - self.repairing_should_be_good = 0 - self.repairing_should_be_restoring = 0 - self.repairing_should_be_corrupt = 0 - self.repairing_should_be_destroyed = ( - 0 # Repairing does not fix destroyed state - you need to restore - ) - self.repairing = 0 - self.restoring_should_be_good = 0 - 
self.restoring_should_be_repairing = 0 - self.restoring_should_be_corrupt = ( - 0 # Not the optimal method (as repair will fix corruption) - ) - self.restoring_should_be_destroyed = 0 - self.restoring = 0 - self.corrupt_should_be_good = 0 - self.corrupt_should_be_repairing = 0 - self.corrupt_should_be_restoring = 0 - self.corrupt_should_be_destroyed = 0 - self.corrupt = 0 - self.destroyed_should_be_good = 0 - self.destroyed_should_be_repairing = 0 - self.destroyed_should_be_restoring = 0 - self.destroyed_should_be_corrupt = 0 - self.destroyed = 0 - self.scanning = 0 - # IER status - self.red_ier_running = 0 - self.green_ier_blocked = 0 - - # Patching / Reset - self.os_patching_duration = 0 # The time taken to patch the OS - self.node_reset_duration = 0 # The time taken to reset a node (hardware) - self.service_patching_duration = 0 # The time taken to patch a service - self.file_system_repairing_limit = 0 # The time take to repair a file - self.file_system_restoring_limit = 0 # The time take to restore a file - self.file_system_scanning_limit = 0 # The time taken to scan the file system diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 138d2742..0c3256d1 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -81,6 +81,7 @@ class ActionType(Enum): NODE = 0 ACL = 1 + ANY = 2 class ObservationType(Enum): diff --git a/src/primaite/common/training_config.py b/src/primaite/common/training_config.py new file mode 100644 index 00000000..347f1c7a --- /dev/null +++ b/src/primaite/common/training_config.py @@ -0,0 +1,91 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
"""The config class."""
from dataclasses import dataclass

from primaite.common.enums import ActionType


@dataclass()
class TrainingConfig:
    """
    Class to hold main config values.

    Every field is declared without a default, so all of them must be
    supplied when an instance is created. The declaration order below is
    the positional order of the generated ``__init__``.
    """

    # Generic
    agent_identifier: str  # The Red Agent algo/class to be used
    action_type: ActionType  # Type of action to use (NODE/ACL/ANY)
    num_episodes: int  # Number of episodes to train over
    num_steps: int  # Number of steps in an episode
    time_delay: int  # Delay between steps (ms) - applies to generic agents only

    # Session / agent loading
    session_type: str  # The session type to run (TRAINING or EVALUATION)
    # NOTE(review): annotated as str, but the packaged
    # training_config_main.yaml sets ``load_agent: False`` (a bool) -
    # confirm the intended type.
    load_agent: str  # Determine whether to load an agent from file
    agent_load_file: str  # File path and file name of agent if you're loading one in

    # Environment
    observation_space_high_value: int  # The high value for the observation space

    # Reward values
    # Generic
    all_ok: int
    # Node Hardware State
    off_should_be_on: int
    off_should_be_resetting: int
    on_should_be_off: int
    on_should_be_resetting: int
    resetting_should_be_on: int
    resetting_should_be_off: int
    resetting: int
    # Node Software or Service State
    good_should_be_patching: int
    good_should_be_compromised: int
    good_should_be_overwhelmed: int
    patching_should_be_good: int
    patching_should_be_compromised: int
    patching_should_be_overwhelmed: int
    patching: int
    compromised_should_be_good: int
    compromised_should_be_patching: int
    compromised_should_be_overwhelmed: int
    compromised: int
    overwhelmed_should_be_good: int
    overwhelmed_should_be_patching: int
    overwhelmed_should_be_compromised: int
    overwhelmed: int
    # Node File System State
    good_should_be_repairing: int
    good_should_be_restoring: int
    good_should_be_corrupt: int
    good_should_be_destroyed: int
    repairing_should_be_good: int
    repairing_should_be_restoring: int
    repairing_should_be_corrupt: int
    repairing_should_be_destroyed: int  # Repairing does not fix destroyed state - you need to restore
    repairing: int
    restoring_should_be_good: int
    restoring_should_be_repairing: int
    restoring_should_be_corrupt: int  # Not the optimal method (as repair will fix corruption)
    restoring_should_be_destroyed: int
    restoring: int
    corrupt_should_be_good: int
    corrupt_should_be_repairing: int
    corrupt_should_be_restoring: int
    corrupt_should_be_destroyed: int
    corrupt: int
    destroyed_should_be_good: int
    destroyed_should_be_repairing: int
    destroyed_should_be_restoring: int
    destroyed_should_be_corrupt: int
    destroyed: int
    scanning: int
    # IER status
    red_ier_running: int
    green_ier_blocked: int

    # Patching / Reset
    os_patching_duration: int  # The time taken to patch the OS
    node_reset_duration: int  # The time taken to reset a node (hardware)
    service_patching_duration: int  # The time taken to patch a service
    file_system_repairing_limit: int  # The time taken to repair a file
    file_system_restoring_limit: int  # The time taken to restore a file
    file_system_scanning_limit: int  # The time taken to scan the file system
a/src/primaite/config/config_5_DATA_MANIPULATION.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml similarity index 99% rename from src/primaite/config/config_5_DATA_MANIPULATION.yaml rename to src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml index 3b29ff4a..1316ccd1 100644 --- a/src/primaite/config/config_5_DATA_MANIPULATION.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml @@ -1,7 +1,4 @@ -- itemType: ACTIONS - type: NODE -- itemType: STEPS - steps: 256 + - itemType: PORTS portsList: - port: '80' diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml new file mode 100644 index 00000000..d01f51f3 --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -0,0 +1,94 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps (for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 
+resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/config/config_UNIT_TEST.yaml b/src/primaite/config/config_UNIT_TEST.yaml deleted file mode 100644 index 3b29ff4a..00000000 --- a/src/primaite/config/config_UNIT_TEST.yaml +++ /dev/null @@ 
-1,533 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: STEPS - steps: 256 -- itemType: PORTS - portsList: - - port: '80' - - port: '1433' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: TCP_SQL - - name: UDP -- itemType: NODE - node_id: '1' - name: CLIENT_1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.11 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: CLIENT_2 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.12 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH_1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.10.1 - software_state: GOOD - file_system_state: GOOD -- itemType: NODE - node_id: '4' - name: SECURITY_SUITE - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.10 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '5' - name: MANAGEMENT_CONSOLE - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.12 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '6' - name: SWITCH_2 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.2.1 - software_state: GOOD - file_system_state: GOOD -- itemType: NODE - node_id: '7' - name: WEB_SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.2.10 - 
software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: TCP_SQL - port: '1433' - state: GOOD -- itemType: NODE - node_id: '8' - name: DATABASE_SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.2.14 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: TCP_SQL - port: '1433' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '9' - name: BACKUP_SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.2.16 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- itemType: LINK - id: '10' - name: LINK_1 - bandwidth: 1000000000 - source: '1' - destination: '3' -- itemType: LINK - id: '11' - name: LINK_2 - bandwidth: 1000000000 - source: '2' - destination: '3' -- itemType: LINK - id: '12' - name: LINK_3 - bandwidth: 1000000000 - source: '3' - destination: '4' -- itemType: LINK - id: '13' - name: LINK_4 - bandwidth: 1000000000 - source: '3' - destination: '5' -- itemType: LINK - id: '14' - name: LINK_5 - bandwidth: 1000000000 - source: '4' - destination: '6' -- itemType: LINK - id: '15' - name: LINK_6 - bandwidth: 1000000000 - source: '5' - destination: '6' -- itemType: LINK - id: '16' - name: LINK_7 - bandwidth: 1000000000 - source: '6' - destination: '7' -- itemType: LINK - id: '17' - name: LINK_8 - bandwidth: 1000000000 - source: '6' - destination: '8' -- itemType: LINK - id: '18' - name: LINK_9 - bandwidth: 1000000000 - source: '6' - destination: '9' -- itemType: GREEN_IER - id: '19' - startStep: 1 - endStep: 256 - load: 10000 - protocol: TCP - port: '80' - source: '1' - destination: '7' - missionCriticality: 5 -- itemType: GREEN_IER - id: '20' - startStep: 1 - endStep: 256 - load: 10000 - protocol: TCP - port: '80' - source: '7' - destination: '1' - missionCriticality: 
5 -- itemType: GREEN_IER - id: '21' - startStep: 1 - endStep: 256 - load: 10000 - protocol: TCP - port: '80' - source: '2' - destination: '7' - missionCriticality: 5 -- itemType: GREEN_IER - id: '22' - startStep: 1 - endStep: 256 - load: 10000 - protocol: TCP - port: '80' - source: '7' - destination: '2' - missionCriticality: 5 -- itemType: GREEN_IER - id: '23' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP_SQL - port: '1433' - source: '7' - destination: '8' - missionCriticality: 5 -- itemType: GREEN_IER - id: '24' - startStep: 1 - endStep: 256 - load: 100000 - protocol: TCP_SQL - port: '1433' - source: '8' - destination: '7' - missionCriticality: 5 -- itemType: GREEN_IER - id: '25' - startStep: 1 - endStep: 256 - load: 50000 - protocol: TCP - port: '80' - source: '1' - destination: '9' - missionCriticality: 2 -- itemType: GREEN_IER - id: '26' - startStep: 1 - endStep: 256 - load: 50000 - protocol: TCP - port: '80' - source: '2' - destination: '9' - missionCriticality: 2 -- itemType: GREEN_IER - id: '27' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP - port: '80' - source: '5' - destination: '7' - missionCriticality: 1 -- itemType: GREEN_IER - id: '28' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP - port: '80' - source: '7' - destination: '5' - missionCriticality: 1 -- itemType: GREEN_IER - id: '29' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP - port: '80' - source: '5' - destination: '8' - missionCriticality: 1 -- itemType: GREEN_IER - id: '30' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP - port: '80' - source: '8' - destination: '5' - missionCriticality: 1 -- itemType: GREEN_IER - id: '31' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP - port: '80' - source: '5' - destination: '9' - missionCriticality: 1 -- itemType: GREEN_IER - id: '32' - startStep: 1 - endStep: 256 - load: 5000 - protocol: TCP - port: '80' - source: '9' - destination: '5' - missionCriticality: 1 -- itemType: ACL_RULE - 
id: '33' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.2.10 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '34' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.2.14 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '35' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.2.14 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '36' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.2.10 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '37' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.10.11 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '38' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.10.12 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '39' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.2.14 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '40' - permission: ALLOW - source: 192.168.2.14 - destination: 192.168.2.10 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '41' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.2.16 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '42' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.2.16 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '43' - permission: ALLOW - source: 192.168.1.12 - destination: 192.168.2.10 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '44' - permission: ALLOW - source: 192.168.1.12 - destination: 192.168.2.14 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '45' - permission: ALLOW - source: 192.168.1.12 - destination: 192.168.2.16 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '46' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.1.12 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '47' - permission: ALLOW - source: 192.168.2.14 - destination: 192.168.1.12 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '48' - 
permission: ALLOW - source: 192.168.2.16 - destination: 192.168.1.12 - protocol: ANY - port: ANY -- itemType: ACL_RULE - id: '49' - permission: DENY - source: ANY - destination: ANY - protocol: ANY - port: ANY -- itemType: RED_POL - id: '50' - startStep: 50 - endStep: 50 - targetNodeId: '1' - initiator: DIRECT - type: SERVICE - protocol: UDP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- itemType: RED_IER - id: '51' - startStep: 75 - endStep: 105 - load: 10000 - protocol: UDP - port: '53' - source: '1' - destination: '8' - missionCriticality: 0 -- itemType: RED_POL - id: '52' - startStep: 100 - endStep: 100 - targetNodeId: '8' - initiator: IER - type: SERVICE - protocol: UDP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- itemType: RED_POL - id: '53' - startStep: 105 - endStep: 105 - targetNodeId: '8' - initiator: SERVICE - type: FILE - protocol: NA - state: CORRUPT - sourceNodeId: '8' - sourceNodeService: UDP - sourceNodeServiceState: COMPROMISED -- itemType: RED_POL - id: '54' - startStep: 105 - endStep: 105 - targetNodeId: '8' - initiator: SERVICE - type: SERVICE - protocol: TCP_SQL - state: COMPROMISED - sourceNodeId: '8' - sourceNodeService: UDP - sourceNodeServiceState: COMPROMISED -- itemType: RED_POL - id: '55' - startStep: 125 - endStep: 125 - targetNodeId: '7' - initiator: SERVICE - type: SERVICE - protocol: TCP - state: OVERWHELMED - sourceNodeId: '8' - sourceNodeService: TCP_SQL - sourceNodeServiceState: COMPROMISED diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py new file mode 100644 index 00000000..b33676a8 --- /dev/null +++ b/src/primaite/config/lay_down_config.py @@ -0,0 +1,69 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
from pathlib import Path
from typing import Final

from primaite import getLogger, USERS_CONFIG_DIR

_LOGGER = getLogger(__name__)

_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"


def _example_lay_down_path(filename: str) -> Path:
    """
    Return the path of an example lay down config, checking that it exists.

    Shared by the public ``*_config_path`` functions so the existence check,
    log call and error message live in one place.

    :param filename: The example lay down config file name.
    :return: The file path.
    :raises FileNotFoundError: If the file is not in the users example
        lay down config directory.
    """
    path = _EXAMPLE_LAY_DOWN / filename
    if not path.exists():
        msg = "Example config not found. Please run 'primaite setup'"
        _LOGGER.critical(msg)
        raise FileNotFoundError(msg)

    return path


def ddos_basic_one_config_path() -> Path:
    """
    The path to the example lay_down_config_1_DDOS_basic.yaml file.

    :return: The file path.
    :raises FileNotFoundError: If the example config file does not exist.
    """
    return _example_lay_down_path("lay_down_config_1_DDOS_basic.yaml")


def ddos_basic_two_config_path() -> Path:
    """
    The path to the example lay_down_config_2_DDOS_basic.yaml file.

    :return: The file path.
    :raises FileNotFoundError: If the example config file does not exist.
    """
    return _example_lay_down_path("lay_down_config_2_DDOS_basic.yaml")


def dos_very_basic_config_path() -> Path:
    """
    The path to the example lay_down_config_3_DOS_very_basic.yaml file.

    :return: The file path.
    :raises FileNotFoundError: If the example config file does not exist.
    """
    return _example_lay_down_path("lay_down_config_3_DOS_very_basic.yaml")


def data_manipulation_config_path() -> Path:
    """
    The path to the example lay_down_config_5_data_manipulation.yaml file.

    :return: The file path.
    :raises FileNotFoundError: If the example config file does not exist.
    """
    return _example_lay_down_path("lay_down_config_5_data_manipulation.yaml")
def main_training_config_path() -> Path:
    """
    The path to the example training_config_main.yaml file.

    :return: The file path.
    :raises FileNotFoundError: If the example config file does not exist.
    """
    config_path = _EXAMPLE_TRAINING / "training_config_main.yaml"
    if config_path.exists():
        return config_path

    msg = "Example config not found. Please run 'primaite setup'"
    _LOGGER.critical(msg)
    raise FileNotFoundError(msg)


def load(file_path: Union[str, Path],
         legacy_file: bool = False) -> TrainingConfig:
    """
    Read in a training config yaml file.

    :param file_path: The config file path.
    :param legacy_file: True if the config file is legacy format, otherwise
        False.
    :return: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    :raises ValueError: If the file_path does not exist.
    :raises TypeError: When the TrainingConfig object cannot be created
        using the values from the config file read from ``file_path``.
    """
    config_path = file_path if isinstance(file_path, Path) else Path(file_path)

    # Guard clause: a missing file is reported and rejected up front.
    if not config_path.exists():
        msg = f"Cannot load the training config as it does not exist: {config_path}"
        _LOGGER.error(msg)
        raise ValueError(msg)

    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    _LOGGER.debug(f"Loading training config file: {config_path}")

    if legacy_file:
        try:
            config = convert_legacy_training_config_dict(config)
        except KeyError:
            # Best effort: keep the unconverted dict and carry on.
            _LOGGER.error(
                f"Failed to convert training config file {config_path} "
                f"from legacy format. Attempting to use file as is."
            )

    # Convert values to Enums
    config["action_type"] = ActionType[config["action_type"]]
    try:
        return TrainingConfig(**config)
    except TypeError:
        _LOGGER.critical(
            f"Error when creating an instance of {TrainingConfig} "
            f"from the training config file {config_path}",
            exc_info=True,
        )
        raise


def convert_legacy_training_config_dict(
    legacy_config_dict: Dict[str, Any],
    num_steps: int = 256,
    action_type: str = "ANY"
) -> Dict[str, Any]:
    """
    Convert a legacy training config dict to the new format.

    :param legacy_config_dict: A legacy training config dict.
    :param num_steps: The number of steps to set as legacy training configs
        don't have num_steps values.
    :param action_type: The action space type to set as legacy training configs
        don't have action_type values.
    :return: The converted training config dict.
    """
    converted: Dict[str, Any] = {
        "num_steps": num_steps,
        "action_type": action_type,
    }
    # Map each legacy key to its new name, one lookup per key.
    renamed_items = (
        (_get_new_key_from_legacy(old_key), old_value)
        for old_key, old_value in legacy_config_dict.items()
    )
    # Keys that map to a falsy new name (e.g. None) are dropped.
    converted.update(
        {new_key: value for new_key, value in renamed_items if new_key}
    )
    return converted
+ "overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised", + "overwhelmed": "overwhelmed", + "goodShouldBeRepairing": "good_should_be_repairing", + "goodShouldBeRestoring": "good_should_be_restoring", + "goodShouldBeCorrupt": "good_should_be_corrupt", + "goodShouldBeDestroyed": "good_should_be_destroyed", + "repairingShouldBeGood": "repairing_should_be_good", + "repairingShouldBeRestoring": "repairing_should_be_restoring", + "repairingShouldBeCorrupt": "repairing_should_be_corrupt", + "repairingShouldBeDestroyed": "repairing_should_be_destroyed", + "repairing": "repairing", + "restoringShouldBeGood": "restoring_should_be_good", + "restoringShouldBeRepairing": "restoring_should_be_repairing", + "restoringShouldBeCorrupt": "restoring_should_be_corrupt", + "restoringShouldBeDestroyed": "restoring_should_be_destroyed", + "restoring": "restoring", + "corruptShouldBeGood": "corrupt_should_be_good", + "corruptShouldBeRepairing": "corrupt_should_be_repairing", + "corruptShouldBeRestoring": "corrupt_should_be_restoring", + "corruptShouldBeDestroyed": "corrupt_should_be_destroyed", + "corrupt": "corrupt", + "destroyedShouldBeGood": "destroyed_should_be_good", + "destroyedShouldBeRepairing": "destroyed_should_be_repairing", + "destroyedShouldBeRestoring": "destroyed_should_be_restoring", + "destroyedShouldBeCorrupt": "destroyed_should_be_corrupt", + "destroyed": "destroyed", + "scanning": "scanning", + "redIerRunning": "red_ier_running", + "greenIerBlocked": "green_ier_blocked", + "osPatchingDuration": "os_patching_duration", + "nodeResetDuration": "node_reset_duration", + "servicePatchingDuration": "service_patching_duration", + "fileSystemRepairingLimit": "file_system_repairing_limit", + "fileSystemRestoringLimit": "file_system_restoring_limit", + "fileSystemScanningLimit": "file_system_scanning_limit", + } + return key_mapping[legacy_key] diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 776f1517..1feffc01 
100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,8 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Tuple +from pathlib import Path +from typing import Dict, Tuple, Union import networkx as nx import numpy as np @@ -28,6 +29,8 @@ from primaite.common.enums import ( SoftwareState, ) from primaite.common.service import Service +from primaite.config import training_config +from primaite.config.training_config import TrainingConfig from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode @@ -56,26 +59,36 @@ class Primaite(Env): OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space - def __init__(self, _config_values, _transaction_list): + def __init__( + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, + ): """ - Init. + The Primaite constructor. - Args: - _episode_steps: The number of steps for the episode - _config_filename: The name of config file - _transaction_list: The list of transactions to populate - _agent_identifier: Identifier for the agent + :param training_config_path: The training config filepath. + :param lay_down_config_path: The lay down config filepath. + :param transaction_list: The list of transactions to populate. + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. 
""" - super(Primaite, self).__init__() + self._training_config_path = training_config_path + self._lay_down_config_path = lay_down_config_path - # Take a copy of the config values - self.config_values = _config_values + self.config_values: TrainingConfig = training_config.load(training_config_path) # Number of steps in an episode - self.episode_steps = 0 + self.episode_steps = self.config_values.num_steps + + super(Primaite, self).__init__() # Transaction list - self.transaction_list = _transaction_list + self.transaction_list = transaction_list # The agent in use self.agent_identifier = self.config_values.agent_identifier @@ -153,13 +166,10 @@ class Primaite(Env): self.observation_type = ObservationType.BOX # Open the config file and build the environment laydown - try: - self.config_file = open(self.config_values.config_filename_use_case, "r") - self.config_data = yaml.safe_load(self.config_file) - self.load_config() - except Exception: - _LOGGER.error("Could not load the environment configuration") - _LOGGER.error("Exception occured", exc_info=True) + with open(self._lay_down_config_path, "r") as file: + # Open the config file and build the environment laydown + self.config_data = yaml.safe_load(file) + self.load_lay_down_config() # Store the node objects as node attributes # (This is so we can access them as objects) @@ -179,12 +189,8 @@ class Primaite(Env): now = datetime.now() # current date and time time = now.strftime("%Y%m%d_%H%M%S") - path = "outputs/diagrams" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - filename = "outputs/diagrams/network_" + time + ".png" - plt.savefig(filename, format="PNG") + file_path = session_path / f"network_{timestamp_str}.png" + plt.savefig(file_path, format="PNG") plt.clf() except Exception: _LOGGER.error("Could not save network diagram") @@ -236,13 +242,9 @@ class Primaite(Env): time = now.strftime("%Y%m%d_%H%M%S") header = ["Episode", "Average Reward"] - # Check whether the output/rerults folder 
exists (doesn't exist by default install) - path = "outputs/results/" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - filename = "outputs/results/average_reward_per_episode_" + time + ".csv" - self.csv_file = open(filename, "w", encoding="UTF8", newline="") + file_name = f"average_reward_per_episode_{timestamp_str}.csv" + file_path = session_path / file_name + self.csv_file = open(file_path, "w", encoding="UTF8", newline="") self.csv_writer = csv.writer(self.csv_file) self.csv_writer.writerow(header) except Exception: @@ -404,7 +406,6 @@ class Primaite(Env): def __close__(self): """Override close function.""" self.csv_file.close() - self.config_file.close() def init_acl(self): """Initialise the Access Control List.""" @@ -888,7 +889,7 @@ class Primaite(Env): elif self.observation_type == ObservationType.MULTIDISCRETE: self._update_env_obs_multidiscrete() - def load_config(self): + def load_lay_down_config(self): """Loads config data in order to build the environment configuration.""" for item in self.config_data: if item["itemType"] == "NODE": @@ -918,15 +919,9 @@ class Primaite(Env): elif item["itemType"] == "PORTS": # Create the list of ports self.create_ports_list(item) - elif item["itemType"] == "ACTIONS": - # Get the action information - self.get_action_info(item) elif item["itemType"] == "OBSERVATIONS": # Get the observation information self.get_observation_info(item) - elif item["itemType"] == "STEPS": - # Get the steps information - self.get_steps_info(item) else: # Do nothing (bad formatting) pass @@ -1247,15 +1242,6 @@ class Primaite(Env): # Set the number of ports self.num_ports = len(self.ports_list) - def get_action_info(self, action_info): - """ - Extracts action_info. - - Args: - item: A config data item representing action info - """ - self.action_type = ActionType[action_info["type"]] - def get_observation_info(self, observation_info): """Extracts observation_info. 
@@ -1264,16 +1250,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_steps_info(self, steps_info): - """ - Extracts steps_info. - - Args: - item: A config data item representing steps info - """ - self.episode_steps = int(steps_info["steps"]) - _LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps") - def reset_environment(self): """ # Resets environment. diff --git a/src/primaite/main.py b/src/primaite/main.py index c963dd00..8a04852b 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,29 +1,41 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """ -Primaite - main (harness) module. +The main PrimAITE session runner module. -Coding Standards: PEP 8 +TODO: This will eventually be refactored out into a proper Session class. +TODO: The passing about of session_dir and timestamp_str is temporary and + will be cleaned up once we move to a proper Session class. """ - -import logging -import os.path +import argparse import time from datetime import datetime +from pathlib import Path +from typing import Final, Union -import yaml from stable_baselines3 import A2C, PPO from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.ppo import MlpPolicy as PPOMlp -from primaite.common.config_values_main import ConfigValuesMain +from primaite import SESSIONS_DIR, getLogger +from primaite.config.lay_down_config import data_manipulation_config_path +from primaite.config.training_config import TrainingConfig, \ + main_training_config_path from primaite.environment.primaite_env import Primaite from primaite.transactions.transactions_to_file import write_transaction_to_file -# FUNCTIONS # +_LOGGER = getLogger(__name__) -def run_generic(): - """Run against a generic agent.""" +def run_generic(env: Primaite, config_values: TrainingConfig): + """ + Run against a generic agent. 
+ + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. + :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + """ for episode in range(0, config_values.num_episodes): env.reset() for step in range(0, config_values.num_steps): @@ -47,9 +59,24 @@ def run_generic(): env.close() -def run_stable_baselines3_ppo(): - """Run against a stable_baselines3 PPO agent.""" - if config_values.load_agent == True: +def run_stable_baselines3_ppo( + env: Primaite, + config_values: TrainingConfig, + session_path: Path, + timestamp_str: str +): + """ + Run against a stable_baselines3 PPO agent. + + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. + :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + if config_values.load_agent: try: agent = PPO.load( config_values.agent_load_file, @@ -62,30 +89,44 @@ def run_stable_baselines3_ppo(): "ERROR: Could not load agent at location: " + config_values.agent_load_file ) - logging.error("Could not load agent") - logging.error("Exception occured", exc_info=True) + _LOGGER.error("Could not load agent") + _LOGGER.error("Exception occured", exc_info=True) else: agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) if config_values.session_type == "TRAINING": # We're in a training session print("Starting training session...") - logging.info("Starting training session...") + _LOGGER.debug("Starting training session...") for episode in range(0, config_values.num_episodes): agent.learn(total_timesteps=1) - save_agent(agent) + _save_agent(agent, session_path, timestamp_str) else: # Default to being in an evaluation session print("Starting evaluation session...") - logging.info("Starting evaluation session...") + _LOGGER.debug("Starting evaluation 
session...") evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) env.close() -def run_stable_baselines3_a2c(): - """Run against a stable_baselines3 A2C agent.""" - if config_values.load_agent == True: +def run_stable_baselines3_a2c( + env: Primaite, + config_values: TrainingConfig, + session_path: Path, timestamp_str: str +): + """ + Run against a stable_baselines3 A2C agent. + + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. + :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + if config_values.load_agent: try: agent = A2C.load( config_values.agent_load_file, @@ -98,284 +139,151 @@ def run_stable_baselines3_a2c(): "ERROR: Could not load agent at location: " + config_values.agent_load_file ) - logging.error("Could not load agent") - logging.error("Exception occured", exc_info=True) + _LOGGER.error("Could not load agent") + _LOGGER.error("Exception occured", exc_info=True) else: agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps) if config_values.session_type == "TRAINING": # We're in a training session print("Starting training session...") - logging.info("Starting training session...") + _LOGGER.debug("Starting training session...") for episode in range(0, config_values.num_episodes): agent.learn(total_timesteps=1) - save_agent(agent) + _save_agent(agent, session_path, timestamp_str) else: # Default to being in an evaluation session print("Starting evaluation session...") - logging.info("Starting evaluation session...") + _LOGGER.debug("Starting evaluation session...") evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) env.close() -def save_agent(_agent): - """Persist an agent (only works for stable baselines3 agents at present).""" - now = datetime.now() # current date and time - 
time = now.strftime("%Y%m%d_%H%M%S") +def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): + """ + Persist an agent. - try: - path = "outputs/agents/" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - filename = "outputs/agents/agent_saved_" + time - _agent.save(filename) - logging.info("Trained agent saved as " + filename) - except Exception: - logging.error("Could not save agent") - logging.error("Exception occured", exc_info=True) + Only works for stable baselines3 agents at present. + + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + if not isinstance(agent, OnPolicyAlgorithm): + msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." + _LOGGER.error(msg) + else: + filepath = session_path / f"agent_saved_{timestamp_str}" + agent.save(filepath) + _LOGGER.debug(f"Trained agent saved as: {filepath}") -def configure_logging(): - """Configures logging.""" - try: - now = datetime.now() # current date and time - time = now.strftime("%Y%m%d_%H%M%S") - filename = "logs/app_" + time + ".log" - path = "logs/" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - logging.basicConfig( - filename=filename, - filemode="w", - format="%(asctime)s - %(levelname)s - %(message)s", - datefmt="%d-%b-%y %H:%M:%S", - level=logging.INFO, - ) - except Exception: - print("ERROR: Could not start logging") +def _get_session_path(session_timestamp: datetime) -> Path: + """ + Get the directory path the session will output to. + + This is set in the format of: + ~/primaite/sessions//_. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. 
+ """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = SESSIONS_DIR / date_dir / session_dir + session_path.mkdir(exist_ok=True, parents=True) + + return session_path -def load_config_values(): - """Loads the config values from the main config file into a config object.""" - try: - # Generic - config_values.agent_identifier = config_data["agentIdentifier"] - config_values.num_episodes = int(config_data["numEpisodes"]) - config_values.time_delay = int(config_data["timeDelay"]) - config_values.config_filename_use_case = ( - "config/" + config_data["configFilename"] +def run( + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path] +): + """Run the PrimAITE Session. + + :param training_config_path: The training config filepath. + :param lay_down_config_path: The lay down config filepath. + """ + # Welcome message + print("Welcome to the Primary-level AI Training Environment (PrimAITE)") + + session_timestamp: Final[datetime] = datetime.now() + session_path = _get_session_path(session_timestamp) + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + + print(f"The output directory for this session is: {session_path}") + + # Create a list of transactions + # A transaction is an object holding the: + # - episode # + # - step # + # - initial observation space + # - action + # - reward + # - new observation space + transaction_list = [] + + # Create the Primaite environment + env = Primaite( + training_config_path=training_config_path, + lay_down_config_path=lay_down_config_path, + transaction_list=transaction_list, + session_path=session_path, + timestamp_str=timestamp_str, + ) + + config_values = env.config_values + + # Get the number of steps (which is stored in the child config file) + config_values.num_steps = env.episode_steps + + # Run environment against an agent + if config_values.agent_identifier == "GENERIC": + run_generic(env=env, 
config_values=config_values) + elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": + run_stable_baselines3_ppo( + env=env, + config_values=config_values, + session_path=session_path, + timestamp_str=timestamp_str, ) - config_values.session_type = config_data["sessionType"] - config_values.load_agent = bool(config_data["loadAgent"]) - config_values.agent_load_file = config_data["agentLoadFile"] - # Environment - config_values.observation_space_high_value = int( - config_data["observationSpaceHighValue"] - ) - # Reward values - # Generic - config_values.all_ok = int(config_data["allOk"]) - # Node Hardware State - config_values.off_should_be_on = int(config_data["offShouldBeOn"]) - config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"]) - config_values.on_should_be_off = int(config_data["onShouldBeOff"]) - config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"]) - config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"]) - config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"]) - config_values.resetting = int(config_data["resetting"]) - # Node Software or Service State - config_values.good_should_be_patching = int(config_data["goodShouldBePatching"]) - config_values.good_should_be_compromised = int( - config_data["goodShouldBeCompromised"] - ) - config_values.good_should_be_overwhelmed = int( - config_data["goodShouldBeOverwhelmed"] - ) - config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"]) - config_values.patching_should_be_compromised = int( - config_data["patchingShouldBeCompromised"] - ) - config_values.patching_should_be_overwhelmed = int( - config_data["patchingShouldBeOverwhelmed"] - ) - config_values.patching = int(config_data["patching"]) - config_values.compromised_should_be_good = int( - config_data["compromisedShouldBeGood"] - ) - config_values.compromised_should_be_patching = int( - config_data["compromisedShouldBePatching"] 
- ) - config_values.compromised_should_be_overwhelmed = int( - config_data["compromisedShouldBeOverwhelmed"] - ) - config_values.compromised = int(config_data["compromised"]) - config_values.overwhelmed_should_be_good = int( - config_data["overwhelmedShouldBeGood"] - ) - config_values.overwhelmed_should_be_patching = int( - config_data["overwhelmedShouldBePatching"] - ) - config_values.overwhelmed_should_be_compromised = int( - config_data["overwhelmedShouldBeCompromised"] - ) - config_values.overwhelmed = int(config_data["overwhelmed"]) - # Node File System State - config_values.good_should_be_repairing = int( - config_data["goodShouldBeRepairing"] - ) - config_values.good_should_be_restoring = int( - config_data["goodShouldBeRestoring"] - ) - config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"]) - config_values.good_should_be_destroyed = int( - config_data["goodShouldBeDestroyed"] - ) - config_values.repairing_should_be_good = int( - config_data["repairingShouldBeGood"] - ) - config_values.repairing_should_be_restoring = int( - config_data["repairingShouldBeRestoring"] - ) - config_values.repairing_should_be_corrupt = int( - config_data["repairingShouldBeCorrupt"] - ) - config_values.repairing_should_be_destroyed = int( - config_data["repairingShouldBeDestroyed"] - ) - config_values.repairing = int(config_data["repairing"]) - config_values.restoring_should_be_good = int( - config_data["restoringShouldBeGood"] - ) - config_values.restoring_should_be_repairing = int( - config_data["restoringShouldBeRepairing"] - ) - config_values.restoring_should_be_corrupt = int( - config_data["restoringShouldBeCorrupt"] - ) - config_values.restoring_should_be_destroyed = int( - config_data["restoringShouldBeDestroyed"] - ) - config_values.restoring = int(config_data["restoring"]) - config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"]) - config_values.corrupt_should_be_repairing = int( - config_data["corruptShouldBeRepairing"] - ) 
- config_values.corrupt_should_be_restoring = int( - config_data["corruptShouldBeRestoring"] - ) - config_values.corrupt_should_be_destroyed = int( - config_data["corruptShouldBeDestroyed"] - ) - config_values.corrupt = int(config_data["corrupt"]) - config_values.destroyed_should_be_good = int( - config_data["destroyedShouldBeGood"] - ) - config_values.destroyed_should_be_repairing = int( - config_data["destroyedShouldBeRepairing"] - ) - config_values.destroyed_should_be_restoring = int( - config_data["destroyedShouldBeRestoring"] - ) - config_values.destroyed_should_be_corrupt = int( - config_data["destroyedShouldBeCorrupt"] - ) - config_values.destroyed = int(config_data["destroyed"]) - config_values.scanning = int(config_data["scanning"]) - # IER status - config_values.red_ier_running = int(config_data["redIerRunning"]) - config_values.green_ier_blocked = int(config_data["greenIerBlocked"]) - # Patching / Reset durations - config_values.os_patching_duration = int(config_data["osPatchingDuration"]) - config_values.node_reset_duration = int(config_data["nodeResetDuration"]) - config_values.service_patching_duration = int( - config_data["servicePatchingDuration"] - ) - config_values.file_system_repairing_limit = int( - config_data["fileSystemRepairingLimit"] - ) - config_values.file_system_restoring_limit = int( - config_data["fileSystemRestoringLimit"] - ) - config_values.file_system_scanning_limit = int( - config_data["fileSystemScanningLimit"] + elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": + run_stable_baselines3_a2c( + env=env, + config_values=config_values, + session_path=session_path, + timestamp_str=timestamp_str, ) - logging.info("Training agent: " + config_values.agent_identifier) - logging.info( - "Training environment config: " + config_values.config_filename_use_case + print("Session finished") + _LOGGER.debug("Session finished") + + print("Saving transaction logs...") + _LOGGER.debug("Saving transaction logs...") + + 
write_transaction_to_file( + transaction_list=transaction_list, + session_path=session_path, + timestamp_str=timestamp_str, + ) + + print("Finished") + _LOGGER.debug("Finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tc") + parser.add_argument("--ldc") + args = parser.parse_args() + if not args.tc: + _LOGGER.error( + "Please provide a training config file using the --tc " "argument" ) - logging.info( - "Training cycle has " + str(config_values.num_episodes) + " episodes" + if not args.ldc: + _LOGGER.error( + "Please provide a lay down config file using the --ldc " "argument" ) - - except Exception: - logging.error("Could not save load config data") - logging.error("Exception occured", exc_info=True) - - -# MAIN PROCESS # - -# Starting point - -# Welcome message -print("Welcome to the Primary-level AI Training Environment (PrimAITE)") - -# Configure logging -configure_logging() - -# Open the main config file -try: - config_file_main = open("config/config_main.yaml", "r") - config_data = yaml.safe_load(config_file_main) - # Create a config class - config_values = ConfigValuesMain() - # Load in config data - load_config_values() -except Exception: - logging.error("Could not load main config") - logging.error("Exception occured", exc_info=True) - -# Create a list of transactions -# A transaction is an object holding the: -# - episode # -# - step # -# - initial observation space -# - action -# - reward -# - new observation space -transaction_list = [] - -# Create the Primaite environment -# try: -env = Primaite(config_values, transaction_list) -# logging.info("PrimAITE environment created") -# except Exception: -# logging.error("Could not create PrimAITE environment") -# logging.error("Exception occured", exc_info=True) - -# Get the number of steps (which is stored in the child config file) -config_values.num_steps = env.episode_steps - -# Run environment against an agent -if config_values.agent_identifier == 
"GENERIC": - run_generic() -elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": - run_stable_baselines3_ppo() -elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": - run_stable_baselines3_a2c() - -print("Session finished") -logging.info("Session finished") - -print("Saving transaction logs...") -logging.info("Saving transaction logs...") - -write_transaction_to_file(transaction_list) - -config_file_main.close() - -print("Finished") -logging.info("Finished") + run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index 83b7ab9f..54b0c642 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -3,7 +3,6 @@ import logging from typing import Final -from primaite.common.config_values_main import ConfigValuesMain from primaite.common.enums import ( FileSystemState, HardwareState, @@ -11,6 +10,7 @@ from primaite.common.enums import ( Priority, SoftwareState, ) +from primaite.config.training_config import TrainingConfig from primaite.nodes.node import Node _LOGGER: Final[logging.Logger] = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class ActiveNode(Node): ip_address: str, software_state: SoftwareState, file_system_state: FileSystemState, - config_values: ConfigValuesMain, + config_values: TrainingConfig, ): """ Init. 
diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 449ceb50..8fd69c78 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -2,8 +2,8 @@ """The base Node class.""" from typing import Final -from primaite.common.config_values_main import ConfigValuesMain from primaite.common.enums import HardwareState, NodeType, Priority +from primaite.config.training_config import TrainingConfig class Node: @@ -16,7 +16,7 @@ class Node: node_type: NodeType, priority: Priority, hardware_state: HardwareState, - config_values: ConfigValuesMain, + config_values: TrainingConfig, ): """ Init. @@ -34,7 +34,7 @@ class Node: self.priority = priority self.hardware_state: HardwareState = hardware_state self.resetting_count: int = 0 - self.config_values: ConfigValuesMain = config_values + self.config_values: TrainingConfig = config_values def __repr__(self): """Returns the name of the node.""" diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index c6980e4c..6515097a 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """The Passive Node class (i.e. an actuator).""" -from primaite.common.config_values_main import ConfigValuesMain from primaite.common.enums import HardwareState, NodeType, Priority +from primaite.config.training_config import TrainingConfig from primaite.nodes.node import Node @@ -15,7 +15,7 @@ class PassiveNode(Node): node_type: NodeType, priority: Priority, hardware_state: HardwareState, - config_values: ConfigValuesMain, + config_values: TrainingConfig, ): """ Init. 
diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 161f9249..5e20f783 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -3,7 +3,6 @@ import logging from typing import Dict, Final -from primaite.common.config_values_main import ConfigValuesMain from primaite.common.enums import ( FileSystemState, HardwareState, @@ -12,6 +11,7 @@ from primaite.common.enums import ( SoftwareState, ) from primaite.common.service import Service +from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode _LOGGER: Final[logging.Logger] = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class ServiceNode(ActiveNode): ip_address: str, software_state: SoftwareState, file_system_state: FileSystemState, - config_values: ConfigValuesMain, + config_values: TrainingConfig, ): """ Init. diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py new file mode 100644 index 00000000..d8c93b9a --- /dev/null +++ b/src/primaite/notebooks/__init__.py @@ -0,0 +1,36 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import importlib.util +import os +import subprocess +import sys + +from primaite import NOTEBOOKS_DIR, getLogger + +_LOGGER = getLogger(__name__) + + +def start_jupyter_session(): + """ + Starts a new Jupyter notebook session in the app notebooks directory. + + Currently only works on Windows OS. + + .. todo:: Figure out how to get this working for Linux and MacOS too. + """ + if sys.platform == "win32": + if importlib.util.find_spec("jupyter") is not None: + # Jupyter is installed + working_dir = os.getcwd() + os.chdir(NOTEBOOKS_DIR) + subprocess.Popen("jupyter lab") + os.chdir(working_dir) + else: + # Jupyter is not installed + _LOGGER.error("Cannot start jupyter lab as it is not installed") + else: + msg = ( + "Feature currently only supported on Windows OS. 
For " + "Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter " + "lab' from your Python environment." + ) + _LOGGER.warning(msg) diff --git a/src/primaite/setup/__init__.py b/src/primaite/setup/__init__.py new file mode 100644 index 00000000..63f825c2 --- /dev/null +++ b/src/primaite/setup/__init__.py @@ -0,0 +1 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. diff --git a/src/primaite/setup/old_installation_clean_up.py b/src/primaite/setup/old_installation_clean_up.py new file mode 100644 index 00000000..292535f2 --- /dev/null +++ b/src/primaite/setup/old_installation_clean_up.py @@ -0,0 +1,13 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def run(): + """Perform the full clean-up.""" + pass + + +if __name__ == "__main__": + run() diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py new file mode 100644 index 00000000..5192c48f --- /dev/null +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -0,0 +1,39 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import filecmp +import os +import shutil +from pathlib import Path + +import pkg_resources + +from primaite import NOTEBOOKS_DIR, getLogger + +_LOGGER = getLogger(__name__) + + +def run(overwrite_existing: bool = True): + """ + Resets the demo jupyter notebooks in the users app notebooks directory. + + :param overwrite_existing: A bool to toggle replacing existing edited + notebooks on or off. 
+ """ + notebooks_package_data_root = pkg_resources.resource_filename( + "primaite", "notebooks/_package_data" + ) + for subdir, dirs, files in os.walk(notebooks_package_data_root): + for file in files: + fp = os.path.join(subdir, file) + path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep) + target_fp = NOTEBOOKS_DIR / Path(*path_split) + target_fp.parent.mkdir(exist_ok=True, parents=True) + copy_file = not target_fp.is_file() + + if overwrite_existing and not copy_file: + copy_file = (not filecmp.cmp(fp, target_fp)) and ( + ".ipynb_checkpoints" not in str(target_fp) + ) + + if copy_file: + shutil.copy2(fp, target_fp) + _LOGGER.info(f"Reset example notebook: {target_fp}") diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py new file mode 100644 index 00000000..f4166c6a --- /dev/null +++ b/src/primaite/setup/reset_example_configs.py @@ -0,0 +1,37 @@ +import filecmp +import os +import shutil +from pathlib import Path + +import pkg_resources + +from primaite import USERS_CONFIG_DIR, getLogger + +_LOGGER = getLogger(__name__) + + +def run(overwrite_existing=True): + """ + Resets the example config files in the users app config directory. + + :param overwrite_existing: A bool to toggle replacing existing edited + config on or off. 
+ """ + configs_package_data_root = pkg_resources.resource_filename( + "primaite", "config/_package_data" + ) + + for subdir, dirs, files in os.walk(configs_package_data_root): + for file in files: + fp = os.path.join(subdir, file) + path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep) + target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split) + target_fp.parent.mkdir(exist_ok=True, parents=True) + copy_file = not target_fp.is_file() + + if overwrite_existing and not copy_file: + copy_file = not filecmp.cmp(fp, target_fp) + + if copy_file: + shutil.copy2(fp, target_fp) + _LOGGER.info(f"Reset example config: {target_fp}") diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py new file mode 100644 index 00000000..9f6e8a13 --- /dev/null +++ b/src/primaite/setup/setup_app_dirs.py @@ -0,0 +1,27 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger + +_LOGGER = getLogger(__name__) + + +def run(): + """ + Handles creation of application directories and user directories. + + Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required + app directories in the correct locations based on the users OS. 
+ """ + app_dirs = [ + _USER_DIRS, + NOTEBOOKS_DIR, + LOG_DIR, + ] + + for app_dir in app_dirs: + if not app_dir.is_dir(): + app_dir.mkdir(parents=True, exist_ok=True) + _LOGGER.info(f"Created directory: {app_dir}") + + +if __name__ == "__main__": + run() diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 7a6e212c..f5508bb2 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -2,9 +2,11 @@ """Writes the Transaction log list out to file for evaluation to utilse.""" import csv -import logging -import os.path -from datetime import datetime +from pathlib import Path + +from primaite import getLogger + +_LOGGER = getLogger(__name__) def turn_action_space_to_array(_action_space): @@ -38,18 +40,22 @@ def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): return return_array -def write_transaction_to_file(_transaction_list): +def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): """ Writes transaction logs to file to support training evaluation. - Args: - _transaction_list: The list of transactions from all steps and all episodes - _num_episodes: The number of episodes that were conducted. + :param transaction_list: The list of transactions from all steps and all + episodes. + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. 
""" - # Get the first transaction and use it to determine the makeup of the observation space and action space - # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1" + # Get the first transaction and use it to determine the makeup of the + # observation space and action space + # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action + # space as "AS_1" # This will be tied into the PrimAITE Use Case so that they make sense - template_transation = _transaction_list[0] + template_transation = transaction_list[0] action_length = template_transation.action_space.size obs_shape = template_transation.obs_space_post.shape obs_assets = template_transation.obs_space_post.shape[0] @@ -75,21 +81,15 @@ def write_transaction_to_file(_transaction_list): # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] header = header + action_header + obs_header_initial + obs_header_new - now = datetime.now() # current date and time - time = now.strftime("%Y%m%d_%H%M%S") + try: - path = "outputs/results/" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - - filename = "outputs/results/all_transactions_" + time + ".csv" + filename = session_path / f"all_transactions_{timestamp_str}.csv" csv_file = open(filename, "w", encoding="UTF8", newline="") csv_writer = csv.writer(csv_file) csv_writer.writerow(header) - for transaction in _transaction_list: + for transaction in transaction_list: csv_data = [ str(transaction.timestamp), str(transaction.episode_number), @@ -110,5 +110,4 @@ def write_transaction_to_file(_transaction_list): csv_file.close() except Exception: - logging.error("Could not save the transaction file") - logging.error("Exception occured", exc_info=True) + _LOGGER.error("Could not save the transaction file", exc_info=True) diff --git a/src/primaite/utils/__init__.py b/src/primaite/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/utils/package_data.py 
b/src/primaite/utils/package_data.py new file mode 100644 index 00000000..59f36851 --- /dev/null +++ b/src/primaite/utils/package_data.py @@ -0,0 +1,32 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import os +from pathlib import Path + +import pkg_resources + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def get_file_path(path: str) -> Path: + """ + Get PrimAITE package data. + + :Example: + + >>> from primaite.utils.package_data import get_file_path + >>> main_env_config = get_file_path("config/_package_data/training_config_main.yaml") + + + :param path: The path from the primaite root. + :return: The file path of the package data file. + :raise FileNotFoundError: When the filepath does not exist. + """ + fp = pkg_resources.resource_filename("primaite", path) + if os.path.isfile(fp): + return Path(fp) + else: + msg = f"Cannot PrimAITE package data: {fp}" + _LOGGER.error(msg) + raise FileNotFoundError(msg) diff --git a/tests/config/legacy/__init__.py b/tests/config/legacy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/config/config_main.yaml b/tests/config/legacy/legacy_training_config.yaml similarity index 100% rename from src/primaite/config/config_main.yaml rename to tests/config/legacy/legacy_training_config.yaml diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml new file mode 100644 index 00000000..becc1799 --- /dev/null +++ b/tests/config/legacy/new_training_config.yaml @@ -0,0 +1,94 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: ANY +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time 
delay between steps (for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations 
+os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time taken to repair the file system +file_system_restoring_limit: 5 # The time taken to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index 00f8016e..70458275 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,7 +1,3 @@ -- itemType: ACTIONS - type: NODE -- itemType: STEPS - steps: 15 - itemType: PORTS portsList: - port: '21' diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index 603d03dc..2e752bc9 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -4,86 +4,91 @@ # Choose one of these (dependent on Agent being trained) # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" -# "GENERIC" -agentIdentifier: GENERIC +agent_identifier: GENERIC +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE # Number of episodes to run per session -numEpisodes: 1 +num_episodes: 1 +# Number of time_steps per episode +num_steps: 15 # Time delay between steps (for generic agents) -timeDelay: 1 -# Filename of the scenario / laydown -configFilename: one_node_states_on_off_lay_down_config.yaml +time_delay: 1 + # Type of session to be run (TRAINING or EVALUATION) -sessionType: TRAINING +session_type: TRAINING # Determine whether to load an agent from file -loadAgent: False +load_agent: False # File path and file name of agent if you're loading one in -agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
+agent_load_file: C:\[Path]\[agent_saved_filename.zip] # Environment config values # The high value for the observation space -observationSpaceHighValue: 1000000000 +observation_space_high_value: 1000000000 # Reward values # Generic -allOk: 0 +all_ok: 0 # Node Hardware State -offShouldBeOn: -10 -offShouldBeResetting: -5 -onShouldBeOff: -2 -onShouldBeResetting: -5 -resettingShouldBeOn: -5 -resettingShouldBeOff: -2 +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 resetting: -3 # Node Software or Service State -goodShouldBePatching: 2 -goodShouldBeCompromised: 5 -goodShouldBeOverwhelmed: 5 -patchingShouldBeGood: -5 -patchingShouldBeCompromised: 2 -patchingShouldBeOverwhelmed: 2 +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 patching: -3 -compromisedShouldBeGood: -20 -compromisedShouldBePatching: -20 -compromisedShouldBeOverwhelmed: -20 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 compromised: -20 -overwhelmedShouldBeGood: -20 -overwhelmedShouldBePatching: -20 -overwhelmedShouldBeCompromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 overwhelmed: -20 # Node File System State -goodShouldBeRepairing: 2 -goodShouldBeRestoring: 2 -goodShouldBeCorrupt: 5 -goodShouldBeDestroyed: 10 -repairingShouldBeGood: -5 -repairingShouldBeRestoring: 2 -repairingShouldBeCorrupt: 2 -repairingShouldBeDestroyed: 0 +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 repairing: -3 -restoringShouldBeGood: -10 
-restoringShouldBeRepairing: -2 -restoringShouldBeCorrupt: 1 -restoringShouldBeDestroyed: 2 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 restoring: -6 -corruptShouldBeGood: -10 -corruptShouldBeRepairing: -10 -corruptShouldBeRestoring: -10 -corruptShouldBeDestroyed: 2 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 corrupt: -10 -destroyedShouldBeGood: -20 -destroyedShouldBeRepairing: -20 -destroyedShouldBeRestoring: -20 -destroyedShouldBeCorrupt: -20 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 destroyed: -20 scanning: -2 # IER status -redIerRunning: -5 -greenIerBlocked: -10 +red_ier_running: -5 +green_ier_blocked: -10 # Patching / Reset durations -osPatchingDuration: 5 # The time taken to patch the OS -nodeResetDuration: 5 # The time taken to reset a node (hardware) -servicePatchingDuration: 5 # The time taken to patch a service -fileSystemRepairingLimit: 5 # The time take to repair the file system -fileSystemRestoringLimit: 5 # The time take to restore the file system -fileSystemScanningLimit: 5 # The time taken to scan the file system +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time taken to repair the file system +file_system_restoring_limit: 5 # The time taken to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/conftest.py b/tests/conftest.py index 1e987223..93c2359b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,6 @@ import time from pathlib import Path from typing import Union -import yaml - -from primaite.common.config_values_main import
ConfigValuesMain from primaite.environment.primaite_env import Primaite ACTION_SPACE_NODE_VALUES = 1 @@ -13,159 +10,19 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 def _get_primaite_env_from_config( - main_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] + training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] ): """Takes a config path and returns the created instance of Primaite.""" - - def load_config_values(): - config_values.agent_identifier = config_data["agentIdentifier"] - config_values.num_episodes = int(config_data["numEpisodes"]) - config_values.time_delay = int(config_data["timeDelay"]) - config_values.config_filename_use_case = lay_down_config_path - config_values.session_type = config_data["sessionType"] - config_values.load_agent = bool(config_data["loadAgent"]) - config_values.agent_load_file = config_data["agentLoadFile"] - # Environment - config_values.observation_space_high_value = int( - config_data["observationSpaceHighValue"] - ) - # Reward values - # Generic - config_values.all_ok = int(config_data["allOk"]) - # Node Hardware State - config_values.off_should_be_on = int(config_data["offShouldBeOn"]) - config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"]) - config_values.on_should_be_off = int(config_data["onShouldBeOff"]) - config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"]) - config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"]) - config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"]) - config_values.resetting = int(config_data["resetting"]) - # Node Software or Service State - config_values.good_should_be_patching = int(config_data["goodShouldBePatching"]) - config_values.good_should_be_compromised = int( - config_data["goodShouldBeCompromised"] - ) - config_values.good_should_be_overwhelmed = int( - config_data["goodShouldBeOverwhelmed"] - ) - config_values.patching_should_be_good = 
int(config_data["patchingShouldBeGood"]) - config_values.patching_should_be_compromised = int( - config_data["patchingShouldBeCompromised"] - ) - config_values.patching_should_be_overwhelmed = int( - config_data["patchingShouldBeOverwhelmed"] - ) - config_values.patching = int(config_data["patching"]) - config_values.compromised_should_be_good = int( - config_data["compromisedShouldBeGood"] - ) - config_values.compromised_should_be_patching = int( - config_data["compromisedShouldBePatching"] - ) - config_values.compromised_should_be_overwhelmed = int( - config_data["compromisedShouldBeOverwhelmed"] - ) - config_values.compromised = int(config_data["compromised"]) - config_values.overwhelmed_should_be_good = int( - config_data["overwhelmedShouldBeGood"] - ) - config_values.overwhelmed_should_be_patching = int( - config_data["overwhelmedShouldBePatching"] - ) - config_values.overwhelmed_should_be_compromised = int( - config_data["overwhelmedShouldBeCompromised"] - ) - config_values.overwhelmed = int(config_data["overwhelmed"]) - # Node File System State - config_values.good_should_be_repairing = int( - config_data["goodShouldBeRepairing"] - ) - config_values.good_should_be_restoring = int( - config_data["goodShouldBeRestoring"] - ) - config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"]) - config_values.good_should_be_destroyed = int( - config_data["goodShouldBeDestroyed"] - ) - config_values.repairing_should_be_good = int( - config_data["repairingShouldBeGood"] - ) - config_values.repairing_should_be_restoring = int( - config_data["repairingShouldBeRestoring"] - ) - config_values.repairing_should_be_corrupt = int( - config_data["repairingShouldBeCorrupt"] - ) - config_values.repairing_should_be_destroyed = int( - config_data["repairingShouldBeDestroyed"] - ) - config_values.repairing = int(config_data["repairing"]) - config_values.restoring_should_be_good = int( - config_data["restoringShouldBeGood"] - ) - 
config_values.restoring_should_be_repairing = int( - config_data["restoringShouldBeRepairing"] - ) - config_values.restoring_should_be_corrupt = int( - config_data["restoringShouldBeCorrupt"] - ) - config_values.restoring_should_be_destroyed = int( - config_data["restoringShouldBeDestroyed"] - ) - config_values.restoring = int(config_data["restoring"]) - config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"]) - config_values.corrupt_should_be_repairing = int( - config_data["corruptShouldBeRepairing"] - ) - config_values.corrupt_should_be_restoring = int( - config_data["corruptShouldBeRestoring"] - ) - config_values.corrupt_should_be_destroyed = int( - config_data["corruptShouldBeDestroyed"] - ) - config_values.corrupt = int(config_data["corrupt"]) - config_values.destroyed_should_be_good = int( - config_data["destroyedShouldBeGood"] - ) - config_values.destroyed_should_be_repairing = int( - config_data["destroyedShouldBeRepairing"] - ) - config_values.destroyed_should_be_restoring = int( - config_data["destroyedShouldBeRestoring"] - ) - config_values.destroyed_should_be_corrupt = int( - config_data["destroyedShouldBeCorrupt"] - ) - config_values.destroyed = int(config_data["destroyed"]) - config_values.scanning = int(config_data["scanning"]) - # IER status - config_values.red_ier_running = int(config_data["redIerRunning"]) - config_values.green_ier_blocked = int(config_data["greenIerBlocked"]) - # Patching / Reset durations - config_values.os_patching_duration = int(config_data["osPatchingDuration"]) - config_values.node_reset_duration = int(config_data["nodeResetDuration"]) - config_values.service_patching_duration = int( - config_data["servicePatchingDuration"] - ) - config_values.file_system_repairing_limit = int( - config_data["fileSystemRepairingLimit"] - ) - config_values.file_system_restoring_limit = int( - config_data["fileSystemRestoringLimit"] - ) - config_values.file_system_scanning_limit = int( - 
config_data["fileSystemScanningLimit"] - ) - - config_file_main = open(main_config_path, "r") - config_data = yaml.safe_load(config_file_main) - # Create a config class - config_values = ConfigValuesMain() - # Load in config data - load_config_values() - env = Primaite(config_values, []) + env = Primaite( + training_config_path=training_config_path, + lay_down_config_path=lay_down_config_path, + transaction_list=[], + ) + config_values = env.config_values config_values.num_steps = env.episode_steps + # TODO: This needs to be refactored to happen outside. Should be part of + # a main Session class. if env.config_values.agent_identifier == "GENERIC": run_generic(env, config_values) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 6a187761..6ecc5c1b 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -7,7 +7,8 @@ from tests.conftest import _get_primaite_env_from_config def test_creating_env_with_box_obs(): """Try creating env with box observation space.""" env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", ) env.update_environent_obs() @@ -22,7 +23,8 @@ def test_creating_env_with_box_obs(): def test_creating_env_with_multidiscrete_obs(): """Try creating env with MultiDiscrete observation space.""" env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "multidiscrete_obs_space_laydown_config.yaml", ) diff --git a/tests/test_reward.py b/tests/test_reward.py index 4925a434..c3fcdfc4 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -9,7 +9,8 @@ def
test_rewards_are_being_penalised_at_each_step_function(): When the initial state is OFF compared to reference state which is ON. """ env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", ) diff --git a/tests/test_training_config.py b/tests/test_training_config.py new file mode 100644 index 00000000..9806a566 --- /dev/null +++ b/tests/test_training_config.py @@ -0,0 +1,36 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import yaml + +from primaite.config import training_config +from tests import TEST_CONFIG_ROOT + + +def test_legacy_lay_down_config_yaml_conversion(): + """Tests the conversion of legacy lay down config files.""" + legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + + with open(legacy_path, "r") as file: + legacy_dict = yaml.safe_load(file) + + with open(new_path, "r") as file: + new_dict = yaml.safe_load(file) + + converted_dict = training_config.convert_legacy_training_config_dict( + legacy_dict) + + assert converted_dict == new_dict + + +def test_create_config_values_main_from_file(): + """Tests creating an instance of TrainingConfig from file.""" + new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + + training_config.load(new_path) + + +def test_create_config_values_main_from_legacy_file(): + """Tests creating an instance of TrainingConfig from legacy file.""" + new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" + + training_config.load(new_path, legacy_file=True)