diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..575611dc --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,90 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [2.0.0] - 2023-07-31 + +### Added +- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE. +- Application Directories to enable PrimAITE as a Python package with predefined directories for storage. +- Support for Ray Rllib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib. +- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`. +- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options. +- Session loading to revisit previously run sessions for SB3 Agents. +- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface. +- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs: + 1. Session Metadata + 2. Results + 3. Diagrams + 4. Saved agents (training checkpoints and a final trained agent). +- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup. +- Benchmarking of PrimAITE performance, showcasing session and step durations for reference. +- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`. + +### Changed +- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions. +- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`. +- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names. +- Docs and Tests now sit outside the `src/` directory. +- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages. +- All dependencies are now defined in the `pyproject.toml` file. +- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each. +- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management. +- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations. +- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority. + + +### Fixed +- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation. +- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes. +- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands. +- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making. + +## [1.1.1] - 2023-06-27 + +### Bug Fixes +* Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by: + * Implementing a reference copy of all green IERs (`self.green_iers_reference`). + * Clearing the traffic on reference IERs at the same time as the live IERs. + * Passing the `green_iers_reference` to the `apply_iers` function at the reference stage. + * Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`. + * Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment. + * Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes. + * Removing the unnecessary "Reapply PoL and IERs" action from the step function. + * Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function. + +## [1.1.0] - 2023-03-13 + +### Added +* The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function. +* The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name. +* Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components: + * Blue agent observation space; + * Blue agent action space; + * Reward function; + * Node pattern-of-life. +* The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network). +* New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile. +* NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability. +* The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value. +* The User Guide has been updated to reflect all the above changes. + +### Changed +* "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing. +* "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement. +* "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement. +* "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file. +* Updates to Transactions. + +### Fixed +* Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default. + + + +[unreleased]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/compare/v2.0.0...HEAD +[2.0.0]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/tag/v2.0.0 diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..93d6f98b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 - 2025 Defence Science and Technology Laboratory UK (https://dstl.gov.uk) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/diagram/classes.puml b/diagram/classes.puml new file mode 100644 index 00000000..4505f3e0 --- /dev/null +++ b/diagram/classes.puml @@ -0,0 +1,521 @@ +@startuml classes +set namespaceSeparator none +class "ACLRule" as primaite.acl.acl_rule.ACLRule { + dest_ip : str + permission + port : str + protocol : str + source_ip : str + get_dest_ip() -> str + get_permission() -> str + get_port() -> str + get_protocol() -> str + get_source_ip() -> str +} +class "AbstractObservationComponent" as primaite.environment.observations.AbstractObservationComponent { + current_observation : NotImplementedType, ndarray + env : str + space : Space + structure : List[str] + {abstract}generate_structure() -> List[str] + {abstract}update() -> None +} +class "AccessControlList" as primaite.acl.access_control_list.AccessControlList { + acl + acl_implicit_permission + acl_implicit_rule + max_acl_rules : int + add_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: str) -> None + check_address_match(_rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool + get_dictionary_hash(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int + get_relevant_rules(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> Dict[int, ACLRule] + is_blocked(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool + remove_all_rules() -> None + remove_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None +} +class "AccessControlList_" as primaite.environment.observations.AccessControlList_ { + current_observation : ndarray + space : MultiDiscrete + structure : list + generate_structure() -> List[str] + update() -> None +} + +class "ActiveNode" as primaite.nodes.active_node.ActiveNode { + file_system_action_count : int + file_system_scanning : bool + file_system_scanning_count : int + file_system_state_actual : GOOD + file_system_state_observed : REPAIRING, RESTORING, GOOD + ip_address : str + patching_count : int + software_state + software_state : GOOD + set_file_system_state(file_system_state: FileSystemState) -> None + set_file_system_state_if_not_compromised(file_system_state: FileSystemState) -> None + set_software_state_if_not_compromised(software_state: SoftwareState) -> None + start_file_system_scan() -> None + update_booting_status() -> None + update_file_system_state() -> None + update_os_patching_status() -> None + update_resetting_status() -> None +} +class "AgentSessionABC" as primaite.agents.agent_abc.AgentSessionABC { + checkpoints_path + evaluation_path + is_eval : bool + learning_path + sb3_output_verbose_level : NONE + session_path : Union[str, Path] + session_timestamp : datetime + timestamp_str + uuid + close() -> None + {abstract}evaluate() -> None + {abstract}export() -> None + {abstract}learn() -> None + load(path: Union[str, Path]) -> None + {abstract}save() -> None +} + +class "DoNothingACLAgent" as primaite.agents.simple.DoNothingACLAgent { +} +class "DoNothingNodeAgent" as primaite.agents.simple.DoNothingNodeAgent { +} +class "DummyAgent" as primaite.agents.simple.DummyAgent { +} +class "HardCodedACLAgent" as primaite.agents.hardcoded_acl.HardCodedACLAgent { + get_allow_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule] + get_allow_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule] + get_blocked_green_iers(green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[str, IER] + get_blocking_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule] + get_deny_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule] + get_matching_acl_rules(source_node_id: str, dest_node_id: str, protocol: str, port: str, acl: AccessControlList, nodes: Dict[str, Union[ServiceNode, ActiveNode]], services_list: List[str]) -> Dict[int, ACLRule] + get_matching_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule] +} +class "HardCodedAgentSessionABC" as primaite.agents.hardcoded_abc.HardCodedAgentSessionABC { + is_eval : bool + evaluate() -> None + export() -> None + learn() -> None + load(path: Union[str, Path]) -> None + save() -> None +} +class "HardCodedNodeAgent" as primaite.agents.hardcoded_node.HardCodedNodeAgent { +} +class "IER" as primaite.pol.ier.IER { + dest_node_id : str + end_step : int + id : str + load : int + mission_criticality : int + port : str + protocol : str + running : bool + source_node_id : str + start_step : int + get_dest_node_id() -> str + get_end_step() -> int + get_id() -> str + get_is_running() -> bool + get_load() -> int + get_mission_criticality() -> int + get_port() -> str + get_protocol() -> str + get_source_node_id() -> str + get_start_step() -> int + set_is_running(_value: bool) -> None +} +class "Link" as primaite.links.link.Link { + bandwidth : int + dest_node_name : str + id : str + protocol_list : List[Protocol] + source_node_name : str + add_protocol(_protocol: str) -> None + add_protocol_load(_protocol: str, _load: int) -> None + clear_traffic() -> None + get_bandwidth() -> int + get_current_load() -> int + get_dest_node_name() -> str + get_id() -> str + get_protocol_list() -> List[Protocol] + get_source_node_name() -> str +} +class "LinkTrafficLevels" as primaite.environment.observations.LinkTrafficLevels { + current_observation : ndarray + space : MultiDiscrete + structure : list + generate_structure() -> List[str] + update() -> None +} +class "Node" as primaite.nodes.node.Node { + booting_count : int + config_values + hardware_state : BOOTING, ON, RESETTING, OFF + name : Final[str] + node_id : Final[str] + node_type : Final[NodeType] + priority + resetting_count : int + shutting_down_count : int + reset() -> None + turn_off() -> None + turn_on() -> None + update_booting_status() -> None + update_resetting_status() -> None + update_shutdown_status() -> None +} +class "NodeLinkTable" as primaite.environment.observations.NodeLinkTable { + current_observation : ndarray + space : Box + structure : list + generate_structure() -> List[str] + update() -> None +} +class "NodeStateInstructionGreen" as primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen { + end_step : int + id : str + node_id : str + node_pol_type : str + service_name : str + start_step : int + state : Union['HardwareState', 'SoftwareState', 'FileSystemState'] + get_end_step() -> int + get_node_id() -> str + get_node_pol_type() -> 'NodePOLType' + get_service_name() -> str + get_start_step() -> int + get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState'] +} +class "NodeStateInstructionRed" as primaite.nodes.node_state_instruction_red.NodeStateInstructionRed { + end_step : int + id : str + initiator : str + pol_type + service_name : str + source_node_id : str + source_node_service : str + source_node_service_state : str + start_step : int + state : Union['HardwareState', 'SoftwareState', 'FileSystemState'] + target_node_id : str + get_end_step() -> int + get_initiator() -> 'NodePOLInitiator' + get_pol_type() -> NodePOLType + get_service_name() -> str + get_source_node_id() -> str + get_source_node_service() -> str + get_source_node_service_state() -> str + get_start_step() -> int + get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState'] + get_target_node_id() -> str +} +class "NodeStatuses" as primaite.environment.observations.NodeStatuses { + current_observation : ndarray + space : MultiDiscrete + structure : list + generate_structure() -> List[str] + update() -> None +} +class "ObservationsHandler" as primaite.environment.observations.ObservationsHandler { + current_observation + registered_obs_components : List[AbstractObservationComponent] + space + deregister(obs_component: AbstractObservationComponent) -> None + describe_structure() -> List[str] + from_config(env: 'Primaite', obs_space_config: dict) -> 'ObservationsHandler' + register(obs_component: AbstractObservationComponent) -> None + update_obs() -> None + update_space() -> None +} +class "PassiveNode" as primaite.nodes.passive_node.PassiveNode { + ip_address +} +class "Primaite" as primaite.environment.primaite_env.Primaite { + ACTION_SPACE_ACL_ACTION_VALUES : int + ACTION_SPACE_ACL_PERMISSION_VALUES : int + ACTION_SPACE_NODE_ACTION_VALUES : int + ACTION_SPACE_NODE_PROPERTY_VALUES : int + acl + action_dict : dict, Dict[int, List[int]] + action_space : Discrete, Space + action_type : int + actual_episode_count + agent_identifier + average_reward : float + env_obs : ndarray, tuple + episode_av_reward_writer + episode_count : int + episode_steps : int + green_iers : Dict[str, IER] + green_iers_reference : Dict[str, IER] + lay_down_config + links : Dict[str, Link] + links_post_blue : dict + links_post_pol : dict + links_post_red : dict + links_reference : Dict[str, Link] + max_number_acl_rules : int + network : Graph + network_reference : Graph + node_pol : Dict[str, NodeStateInstructionGreen] + nodes : Dict[str, NodeUnion] + nodes_post_blue : dict + nodes_post_pol : dict + nodes_post_red : dict + nodes_reference : Dict[str, NodeUnion] + num_links : int + num_nodes : int + num_ports : int + num_services : int + obs_config : dict + obs_handler + observation_space : Tuple, Box, Space + observation_type + ports_list : List[str] + red_iers : Dict[str, IER], dict + red_node_pol : dict, Dict[str, NodeStateInstructionRed] + services_list : List[str] + session_path : Final[Path] + step_count : int + step_info : Dict[Any] + timestamp_str : Final[str] + total_reward : float + total_step_count : int + training_config + transaction_writer + apply_actions_to_acl(_action: int) -> None + apply_actions_to_nodes(_action: int) -> None + apply_time_based_updates() -> None + close() -> None + create_acl_action_dict() -> Dict[int, List[int]] + create_acl_rule(item: Dict) -> None + create_green_ier(item: Dict) -> None + create_green_pol(item: Dict) -> None + create_link(item: Dict) -> None + create_node(item: Dict) -> None + create_node_action_dict() -> Dict[int, List[int]] + create_node_and_acl_action_dict() -> Dict[int, List[int]] + create_ports_list(ports: Dict) -> None + create_red_ier(item: Dict) -> None + create_red_pol(item: Dict) -> None + create_services_list(services: Dict) -> None + get_action_info(action_info: Dict) -> None + get_observation_info(observation_info: Dict) -> None + init_acl() -> None + init_observations() -> Tuple[spaces.Space, np.ndarray] + interpret_action_and_apply(_action: int) -> None + load_lay_down_config() -> None + output_link_status() -> None + reset() -> np.ndarray + reset_environment() -> None + reset_node(item: Dict) -> None + save_obs_config(obs_config: dict) -> None + set_as_eval() -> None + step(action: int) -> Tuple[np.ndarray, float, bool, Dict] + update_environent_obs() -> None +} +class "PrimaiteSession" as primaite.primaite_session.PrimaiteSession { + evaluation_path : Optional[Path], Path + is_load_session : bool + learning_path : Optional[Path], Path + session_path : Optional[Path], Path + timestamp_str : str, Optional[str] + close() -> None + evaluate() -> None + learn() -> None + setup() -> None +} +class "Protocol" as primaite.common.protocol.Protocol { + load : int + name : str + add_load(_load: int) -> None + clear_load() -> None + get_load() -> int + get_name() -> str +} +class "RLlibAgent" as primaite.agents.rllib.RLlibAgent { + {abstract}evaluate() -> None + {abstract}export() -> None + learn() -> None + {abstract}load(path: Union[str, Path]) -> RLlibAgent + save(overwrite_existing: bool) -> None +} +class "RandomAgent" as primaite.agents.simple.RandomAgent { +} +class "SB3Agent" as primaite.agents.sb3.SB3Agent { + is_eval : bool + evaluate() -> None + {abstract}export() -> None + learn() -> None + save() -> None +} +class "Service" as primaite.common.service.Service { + name : str + patching_count : int + port : str + software_state : GOOD + reduce_patching_count() -> None +} +class "ServiceNode" as primaite.nodes.service_node.ServiceNode { + services : Dict[str, Service] + add_service(service: Service) -> None + get_service_state(protocol_name: str) -> SoftwareState + has_service(protocol_name: str) -> bool + service_is_overwhelmed(protocol_name: str) -> bool + service_running(protocol_name: str) -> bool + set_service_state(protocol_name: str, software_state: SoftwareState) -> None + set_service_state_if_not_compromised(protocol_name: str, software_state: SoftwareState) -> None + update_booting_status() -> None + update_resetting_status() -> None + update_services_patching_status() -> None +} +class "SessionOutputWriter" as primaite.utils.session_output_writer.SessionOutputWriter { + learning_session : bool + transaction_writer : bool + close() -> None + write(data: Union[Tuple, Transaction]) -> None +} +class "TrainingConfig" as primaite.config.training_config.TrainingConfig { + action_type + agent_framework + agent_identifier + agent_load_file : Optional[str] + all_ok : float + checkpoint_every_n_episodes : int + compromised : float + compromised_should_be_good : float + compromised_should_be_overwhelmed : float + compromised_should_be_patching : float + corrupt : float + corrupt_should_be_destroyed : float + corrupt_should_be_good : float + corrupt_should_be_repairing : float + corrupt_should_be_restoring : float + deep_learning_framework + destroyed : float + destroyed_should_be_corrupt : float + destroyed_should_be_good : float + destroyed_should_be_repairing : float + destroyed_should_be_restoring : float + deterministic : bool + file_system_repairing_limit : int + file_system_restoring_limit : int + file_system_scanning_limit : int + good_should_be_compromised : float + good_should_be_corrupt : float + good_should_be_destroyed : float + good_should_be_overwhelmed : float + good_should_be_patching : float + good_should_be_repairing : float + good_should_be_restoring : float + green_ier_blocked : float + hard_coded_agent_view + implicit_acl_rule + load_agent : bool + max_number_acl_rules : int + node_booting_duration : int + node_reset_duration : int + node_shutdown_duration : int + num_eval_episodes : int + num_eval_steps : int + num_train_episodes : int + num_train_steps : int + observation_space : dict + observation_space_high_value : int + off_should_be_on : float + off_should_be_resetting : float + on_should_be_off : float + on_should_be_resetting : float + os_patching_duration : int + overwhelmed : float + overwhelmed_should_be_compromised : float + overwhelmed_should_be_good : float + overwhelmed_should_be_patching : float + patching : float + patching_should_be_compromised : float + patching_should_be_good : float + patching_should_be_overwhelmed : float + random_red_agent : bool + red_ier_running : float + repairing : float + repairing_should_be_corrupt : float + repairing_should_be_destroyed : float + repairing_should_be_good : float + repairing_should_be_restoring : float + resetting : float + resetting_should_be_off : float + resetting_should_be_on : float + restoring : float + restoring_should_be_corrupt : float + restoring_should_be_destroyed : float + restoring_should_be_good : float + restoring_should_be_repairing : float + sb3_output_verbose_level + scanning : float + seed : Optional[int] + service_patching_duration : int + session_type + time_delay : int + from_dict(config_dict: Dict[str, Any]) -> TrainingConfig + to_dict(json_serializable: bool) -> Dict +} +class "Transaction" as primaite.transactions.transaction.Transaction { + action_space : Optional[int] + agent_identifier + episode_number : int + obs_space : str + obs_space_description : NoneType, Optional[List[str]], list + obs_space_post : Optional[Union['np.ndarray', Tuple['np.ndarray']]] + obs_space_pre : Optional[Union['np.ndarray', Tuple['np.ndarray']]] + reward : Optional[float], float + step_number : int + timestamp : datetime + as_csv_data() -> Tuple[List, List] +} +primaite.agents.hardcoded_abc.HardCodedAgentSessionABC --|> primaite.agents.agent_abc.AgentSessionABC +primaite.agents.hardcoded_acl.HardCodedACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC +primaite.agents.hardcoded_node.HardCodedNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC +primaite.agents.rllib.RLlibAgent --|> primaite.agents.agent_abc.AgentSessionABC +primaite.agents.sb3.SB3Agent --|> primaite.agents.agent_abc.AgentSessionABC +primaite.agents.simple.DoNothingACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC +primaite.agents.simple.DoNothingNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC +primaite.agents.simple.DummyAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC +primaite.agents.simple.RandomAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC +primaite.environment.observations.AccessControlList_ --|> primaite.environment.observations.AbstractObservationComponent +primaite.environment.observations.LinkTrafficLevels --|> primaite.environment.observations.AbstractObservationComponent +primaite.environment.observations.NodeLinkTable --|> primaite.environment.observations.AbstractObservationComponent +primaite.environment.observations.NodeStatuses --|> primaite.environment.observations.AbstractObservationComponent +primaite.nodes.active_node.ActiveNode --|> primaite.nodes.node.Node +primaite.nodes.passive_node.PassiveNode --|> primaite.nodes.node.Node +primaite.nodes.service_node.ServiceNode --|> primaite.nodes.active_node.ActiveNode +primaite.common.service.Service --|> primaite.nodes.service_node.ServiceNode +primaite.acl.access_control_list.AccessControlList --* primaite.environment.primaite_env.Primaite : acl +primaite.acl.acl_rule.ACLRule --* primaite.acl.access_control_list.AccessControlList : acl_implicit_rule +primaite.agents.hardcoded_acl.HardCodedACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.hardcoded_node.HardCodedNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.rllib.RLlibAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.sb3.SB3Agent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.simple.DoNothingACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.simple.DoNothingNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.simple.DummyAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.agents.simple.RandomAgent --* primaite.primaite_session.PrimaiteSession : _agent_session +primaite.config.training_config.TrainingConfig --* primaite.agents.agent_abc.AgentSessionABC : _training_config +primaite.config.training_config.TrainingConfig --* primaite.environment.primaite_env.Primaite : training_config +primaite.environment.observations.ObservationsHandler --* primaite.environment.primaite_env.Primaite : obs_handler +primaite.environment.primaite_env.Primaite --* primaite.agents.agent_abc.AgentSessionABC : _env +primaite.environment.primaite_env.Primaite --* primaite.agents.hardcoded_abc.HardCodedAgentSessionABC : _env +primaite.environment.primaite_env.Primaite --* primaite.agents.sb3.SB3Agent : _env +primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : episode_av_reward_writer +primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : transaction_writer +primaite.config.training_config.TrainingConfig --o primaite.nodes.node.Node : config_values +primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen --* primaite.environment.primaite_env.Primaite +primaite.nodes.node_state_instruction_red.NodeStateInstructionRed --* primaite.environment.primaite_env.Primaite +primaite.pol.ier.IER --* primaite.environment.primaite_env.Primaite +primaite.common.protocol.Protocol --o primaite.links.link.Link +primaite.links.link.Link --* primaite.environment.primaite_env.Primaite +primaite.config.training_config.TrainingConfig --o primaite.nodes.active_node.ActiveNode +primaite.utils.session_output_writer.SessionOutputWriter --> primaite.transactions.transaction.Transaction +primaite.transactions.transaction.Transaction --> primaite.environment.primaite_env.Primaite +@enduml diff --git a/docs/_templates/custom-class-template.rst b/docs/_templates/custom-class-template.rst index acffdc4c..66acd325 100644 --- a/docs/_templates/custom-class-template.rst +++ b/docs/_templates/custom-class-template.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates. diff --git a/docs/_templates/custom-module-template.rst b/docs/_templates/custom-module-template.rst index 8eebad3e..64ac520a 100644 --- a/docs/_templates/custom-module-template.rst +++ b/docs/_templates/custom-module-template.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates. diff --git a/docs/api.rst b/docs/api.rst index b24dafc3..aeaef4e2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without diff --git a/docs/conf.py b/docs/conf.py index 8afc1246..4a805916 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Configuration file for the Sphinx documentation builder. # # For the full list of built-in configuration values, see the documentation: @@ -19,8 +19,8 @@ sys.path.insert(0, os.path.abspath("../")) # -- Project information ----------------------------------------------------- year = datetime.datetime.now().year project = "PrimAITE" -copyright = f"Copyright (C) QinetiQ Training and Simulation Ltd 2021 - {year}" -author = "QinetiQ Training and Simulation Ltd" +copyright = f"Copyright (C) Defence Science and Technology Laboratory UK 2021 - {year}" +author = "Defence Science and Technology Laboratory UK" # The short Major.Minor.Build version with open("../src/primaite/VERSION", "r") as file: diff --git a/docs/index.rst b/docs/index.rst index de5bed46..208d5abc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK Welcome to PrimAITE's documentation ==================================== diff --git a/docs/source/about.rst b/docs/source/about.rst index 2068472c..d12a59de 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. _about: diff --git a/docs/source/config.rst b/docs/source/config.rst index 67bb86d8..daf7f90b 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. _config: @@ -17,10 +17,10 @@ PrimAITE uses two configuration files for its operation: Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERSs) and Access Control Rules. -Environment Config: +Training Config: ******************* -The environment config file consists of the following attributes: +The Training Config file consists of the following attributes: **Generic Config Values** diff --git a/docs/source/custom_agent.rst b/docs/source/custom_agent.rst index ba438305..0f4f30ad 100644 --- a/docs/source/custom_agent.rst +++ b/docs/source/custom_agent.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK Custom Agents ============= @@ -130,7 +130,7 @@ Finally, specify your agent in your training config. .. code-block:: yaml - # ~/primaite/config/path/to/your/config_main.yaml + # ~/primaite/2.0.0rc2/config/path/to/your/config_main.yaml # Training Config File diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst index 0d3f21c3..942ccfd8 100644 --- a/docs/source/dependencies.rst +++ b/docs/source/dependencies.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. role:: raw-html(raw) :format: html diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 13c9d699..0ac2fdd4 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. _getting-started: @@ -41,12 +41,12 @@ Install PrimAITE .. code-tab:: bash :caption: Unix - mkdir ~/primaite + mkdir ~/primaite/2.0.0rc2 .. code-tab:: powershell :caption: Windows (Powershell) - mkdir ~\primaite + mkdir ~\primaite\2.0.0rc2 2. Navigate to the primaite directory and create a new python virtual environment (venv) @@ -55,13 +55,13 @@ Install PrimAITE .. code-tab:: bash :caption: Unix - cd ~/primaite + cd ~/primaite/2.0.0rc2 python3 -m venv .venv .. code-tab:: powershell :caption: Windows (Powershell) - cd ~\primaite + cd ~\primaite\2.0.0rc2 python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 3422d51e..8340d559 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK Glossary ============= @@ -77,5 +77,5 @@ Glossary Gym PrimAITE uses the Gym reinforcement learning framework API to create a training environment and interface with RL agents. Gym defines a common way of creating observations, actions, and rewards. - User data directory - PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite` on linux/darwin and `C:\Users\\primaite` on Windows. + User app home + PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite` on linux/darwin and `C:\Users\\primaite\` on Windows. diff --git a/docs/source/migration_1.2_-_2.0.rst b/docs/source/migration_1.2_-_2.0.rst index bc90a5c3..e1e24790 100644 --- a/docs/source/migration_1.2_-_2.0.rst +++ b/docs/source/migration_1.2_-_2.0.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK v1.2 to v2.0 Migration guide ============================ @@ -31,7 +31,7 @@ v1.2 to v2.0 Migration guide **3. Location of configs** - In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, and now when you install and setup PrimAITE, config files are stored in your user data location. On Linux/OSX, this is stored in ``~/primaite/config``. On Windows, this is stored in ``C:\Users\\primaite\configs``. Upon first setup, the configs folder is populated with some default yaml files. It is recommended that you store all your custom configuration files here. + In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, and now when you install and setup PrimAITE, config files are stored in your user data location. On Linux/OSX, this is stored in ``~/primaite/2.0.0rc2/config``. On Windows, this is stored in ``C:\Users\\primaite\configs``. Upon first setup, the configs folder is populated with some default yaml files. It is recommended that you store all your custom configuration files here. **4. Contents of configs** diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index c081d0d9..840e5717 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -1,6 +1,6 @@ .. only:: comment - Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK .. _run a primaite session: @@ -20,16 +20,16 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf .. code-tab:: bash :caption: Unix CLI - cd ~/primaite + cd ~/primaite/2.0.0rc2 source ./.venv/bin/activate - primaite session ./config/my_training_config.yaml ./config/my_lay_down_config.yaml + primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml .. code-tab:: powershell :caption: Powershell CLI - cd ~\primaite + cd ~\primaite\2.0.0rc2 .\.venv\Scripts\activate - primaite session .\config\my_training_config.yaml .\config\my_lay_down_config.yaml + primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml .. code-tab:: python @@ -41,11 +41,13 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf lay_down_config = run(training_config, lay_down_config) -When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/sessions``). -The sub-directory is formatted as such: ``~/primaite/sessions//_/`` +When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0rc2/sessions``). +The sub-directory is formatted as such: ``~/primaite/2.0.0rc2/sessions//_/`` For example, when running a session at 17:30:00 on 31st January 2023, the session will output to: -``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``. +``~/primaite/2.0.0rc2/sessions/2023-01-31/2023-01-31_17-30-00/``. + +``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``. Outputs @@ -108,43 +110,44 @@ For each training session, assuming the agent being trained implements the *save ~/ └── primaite/ - └── sessions/ - └── 2023-07-18/ - └── 2023-07-18_11-06-04/ - ├── evaluation/ - │ ├── all_transactions_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv - │ └── average_reward_per_episode_2023-07-18_11-06-04.png - ├── learning/ - │ ├── all_transactions_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.png - │ ├── checkpoints/ - │ │ └── sb3ppo_10.zip - │ ├── SB3_PPO.zip - │ └── tensorboard_logs/ - │ ├── PPO_1/ - │ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0 - │ ├── PPO_2/ - │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1 - │ ├── PPO_3/ - │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2 - │ ├── PPO_4/ - │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3 - │ ├── PPO_5/ - │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4 - │ ├── PPO_6/ - │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5 - │ ├── PPO_7/ - │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6 - │ ├── PPO_8/ - │ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7 - │ ├── PPO_9/ - │ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8 - │ └── PPO_10/ - │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9 - ├── network_2023-07-18_11-06-04.png - └── session_metadata.json + └── 2.0.0rc2/ + └── sessions/ + └── 2023-07-18/ + └── 2023-07-18_11-06-04/ + ├── evaluation/ + │ ├── all_transactions_2023-07-18_11-06-04.csv + │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv + │ └── average_reward_per_episode_2023-07-18_11-06-04.png + ├── learning/ + │ ├── all_transactions_2023-07-18_11-06-04.csv + │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv + │ ├── average_reward_per_episode_2023-07-18_11-06-04.png + │ ├── checkpoints/ + │ │ └── sb3ppo_10.zip + │ ├── SB3_PPO.zip + │ └── tensorboard_logs/ + │ ├── PPO_1/ + │ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0 + │ ├── PPO_2/ + │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1 + │ ├── PPO_3/ + │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2 + │ ├── PPO_4/ + │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3 + │ ├── PPO_5/ + │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4 + │ ├── PPO_6/ + │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5 + │ ├── PPO_7/ + │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6 + │ ├── PPO_8/ + │ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7 + │ ├── PPO_9/ + │ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8 + │ └── PPO_10/ + │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9 + ├── network_2023-07-18_11-06-04.png + └── session_metadata.json Loading a session ----------------- @@ -157,14 +160,14 @@ A previous session can be loaded by providing the **directory** of the previous .. code-tab:: bash :caption: Unix CLI - cd ~/primaite + cd ~/primaite/2.0.0rc2 source ./.venv/bin/activate primaite session --load "path/to/session" .. code-tab:: bash :caption: Powershell CLI - cd ~\primaite + cd ~\primaite\2.0.0rc2 .\.venv\Scripts\activate primaite session --load "path\to\session" diff --git a/pyproject.toml b/pyproject.toml index fc0551c3..9691f65c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,13 +5,13 @@ build-backend = "setuptools.build_meta" [project] name = "primaite" description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme." -authors = [{name="QinetiQ Training and Simulation Ltd"}] -license = {text = "GFX"} +authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}] +license = {file = "LICENSE"} requires-python = ">=3.8, <3.11" dynamic = ["version", "readme"] classifiers = [ - "License :: GFX", - "Development Status :: 4 - Beta", + "License :: OSI Approved :: MIT License", + "Development Status :: 5 - Production/Stable", "Operating System :: Microsoft :: Windows", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", @@ -55,8 +55,10 @@ dev = [ "build==0.10.0", "flake8==6.0.0", "furo==2023.3.27", + "gputil==1.4.0", "pip-licenses==4.3.0", "pre-commit==2.20.0", + "pylatex==1.4.1", "pytest==7.2.0", "pytest-xdist==3.3.1", "pytest-cov==4.0.0", diff --git a/setup.py b/setup.py index efaf24bf..62bcbf16 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from setuptools import setup from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 4111d137..41ab234f 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0rc1 +2.0.0rc2 diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index c348681d..a0f5b7fe 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -1,23 +1,126 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import logging import logging.config +import shutil import sys from bisect import bisect from logging import Formatter, Logger, LogRecord, StreamHandler from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Any, Dict, Final +from typing import Any, Dict, Final, List import pkg_resources import yaml from platformdirs import PlatformDirs -_PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") -"""An instance of `PlatformDirs` set with appname='primaite'.""" +with open(Path(__file__).parent.resolve() / "VERSION", "r") as file: + __version__ = file.readline().strip() + + +class _PrimaitePaths: + """ + A Primaite paths class that leverages PlatformDirs. + + The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__`. + """ + + def __init__(self): + self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) + + def _get_dirs_properties(self) -> List[str]: + class_items = self.__class__.__dict__.items() + return [k for k, v in class_items if isinstance(v, property)] + + def mkdirs(self): + """ + Creates all Primaite directories. + + Does this by retrieving all properties in the PrimaiteDirs class and calls each one. + """ + for p in self._get_dirs_properties(): + getattr(self, p) + + @property + def user_home_path(self) -> Path: + """The PrimAITE user home path.""" + path = Path.home() / "primaite" / __version__ + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def user_sessions_path(self) -> Path: + """The PrimAITE user sessions path.""" + path = self.user_home_path / "sessions" + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def user_config_path(self) -> Path: + """The PrimAITE user config path.""" + path = self.user_home_path / "config" + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def user_notebooks_path(self) -> Path: + """The PrimAITE user notebooks path.""" + path = self.user_home_path / "notebooks" + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def app_home_path(self) -> Path: + """The PrimAITE app home path.""" + path = self._dirs.user_data_path + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def app_config_dir_path(self) -> Path: + """The PrimAITE app config directory path.""" + path = self._dirs.user_config_path + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def app_config_file_path(self) -> Path: + """The PrimAITE app config file path.""" + return self.app_config_dir_path / "primaite_config.yaml" + + @property + def app_log_dir_path(self) -> Path: + """The PrimAITE app log directory path.""" + if sys.platform == "win32": + path = self.app_home_path / "logs" + else: + path = self._dirs.user_log_path + path.mkdir(exist_ok=True, parents=True) + return path + + @property + def app_log_file_path(self) -> Path: + """The PrimAITE app log file path.""" + return self.app_log_dir_path / "primaite.log" + + def __repr__(self): + properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()]) + return f"{self.__class__.__name__}({properties_str})" + + +PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths() + + +def _host_primaite_config(): + if not PRIMAITE_PATHS.app_config_file_path.exists(): + pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) + shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path) + + +_host_primaite_config() def _get_primaite_config() -> Dict: - config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" + config_path = PRIMAITE_PATHS.app_config_file_path if not config_path.exists(): config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) with open(config_path, "r") as file: @@ -36,35 +139,7 @@ def _get_primaite_config() -> Dict: _PRIMAITE_CONFIG = _get_primaite_config() -_USER_DIRS: Final[Path] = Path.home() / "primaite" -"""The users home space for PrimAITE which is located at: ~/primaite.""" -NOTEBOOKS_DIR: Final[Path] = _USER_DIRS / "notebooks" -""" -The path to the users notebooks directory as an instance of `Path` or -`PosixPath`, depending on the OS. - -Users notebooks are stored at: ``~/primaite/notebooks``. -""" - -USERS_CONFIG_DIR: Final[Path] = _USER_DIRS / "config" -""" -The path to the users config directory as an instance of `Path` or -`PosixPath`, depending on the OS. - -Users config files are stored at: ``~/primaite/config``. -""" - -SESSIONS_DIR: Final[Path] = _USER_DIRS / "sessions" -""" -The path to the users PrimAITE Sessions directory as an instance of `Path` or -`PosixPath`, depending on the OS. - -Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. -""" - - -# region Setup Logging class _LevelFormatter(Formatter): """ A custom level-specific formatter. @@ -87,14 +162,6 @@ class _LevelFormatter(Formatter): return formatter.format(record) -def _log_dir() -> Path: - if sys.platform == "win32": - dir_path = _PLATFORM_DIRS.user_data_path / "logs" - else: - dir_path = _PLATFORM_DIRS.user_log_path - return dir_path - - _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], @@ -105,18 +172,10 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( } ) -LOG_DIR: Final[Path] = _log_dir() -"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS.""" - -LOG_DIR.mkdir(exist_ok=True, parents=True) - -LOG_PATH: Final[Path] = LOG_DIR / "primaite.log" -"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS.""" - _STREAM_HANDLER: Final[StreamHandler] = StreamHandler() _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( - filename=LOG_PATH, + filename=PRIMAITE_PATHS.app_log_file_path, maxBytes=10485760, # 10MB backupCount=9, # Max 100MB of logs encoding="utf8", @@ -146,10 +205,3 @@ def getLogger(name: str) -> Logger: # noqa logger.setLevel(_PRIMAITE_CONFIG["log_level"]) return logger - - -# endregion - - -with open(Path(__file__).parent.resolve() / "VERSION", "r") as file: - __version__ = file.readline() diff --git a/src/primaite/acl/__init__.py b/src/primaite/acl/__init__.py index c6fd79f2..6dc02583 100644 --- a/src/primaite/acl/__init__.py +++ b/src/primaite/acl/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Access Control List. Models firewall functionality.""" diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index c61b0c10..88943f8f 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """A class that implements the access control list implementation for the network.""" import logging from typing import Dict, Final, List, Union diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index 53c860cd..9c8deacd 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """A class that implements an access control list rule.""" from primaite.common.enums import RulePermissionType diff --git a/src/primaite/agents/__init__.py b/src/primaite/agents/__init__.py index d987b43f..c742daf3 100644 --- a/src/primaite/agents/__init__.py +++ b/src/primaite/agents/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Common interface between RL agents from different libraries and PrimAITE.""" diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index 3c18e1f3..c314fec3 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from __future__ import annotations import json @@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Union from uuid import uuid4 import primaite -from primaite import getLogger, SESSIONS_DIR +from primaite import getLogger, PRIMAITE_PATHS from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.data_viz.session_plots import plot_av_reward_per_episode @@ -25,14 +25,14 @@ def get_session_path(session_timestamp: datetime) -> Path: Get the directory path the session will output to. This is set in the format of: - ~/primaite/sessions//_. + ~/primaite/2.0.0rc2/sessions//_. :param session_timestamp: This is the datetime that the session started. :return: The session directory path. """ date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_path + session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path @@ -230,7 +230,6 @@ class AgentSessionABC(ABC): self._update_session_metadata_file() self._can_evaluate = True self.is_eval = False - self._plot_av_reward_per_episode(learning_session=True) @abstractmethod def evaluate( @@ -243,9 +242,9 @@ class AgentSessionABC(ABC): :param kwargs: Any agent-specific key-word args to be passed. """ if self._can_evaluate: - self._plot_av_reward_per_episode(learning_session=False) self._update_session_metadata_file() self.is_eval = True + self._plot_av_reward_per_episode(learning_session=False) _LOGGER.info("Finished evaluation") @abstractmethod diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py index 0336f00e..e75edbc5 100644 --- a/src/primaite/agents/hardcoded_abc.py +++ b/src/primaite/agents/hardcoded_abc.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import time from abc import abstractmethod from pathlib import Path diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index b8c49c14..2440da06 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from typing import Dict, List, Union import numpy as np diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 10cc2b72..b08d8967 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import numpy as np from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index bde3a621..ab1b3af3 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,8 +1,9 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from __future__ import annotations import json import shutil +import zipfile from datetime import datetime from logging import Logger from pathlib import Path @@ -17,8 +18,9 @@ from ray.tune.registry import register_env from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType from primaite.environment.primaite_env import Primaite +from primaite.exceptions import RLlibAgentError _LOGGER: Logger = getLogger(__name__) @@ -68,11 +70,14 @@ class RLlibAgent(AgentSessionABC): # TODO: implement RLlib agent loading if session_path is not None: msg = "RLlib agent loading has not been implemented yet" - _LOGGER.error(msg) - print(msg) - raise NotImplementedError + _LOGGER.critical(msg) + raise NotImplementedError(msg) super().__init__(training_config_path, lay_down_config_path) + if self._training_config.session_type == SessionType.EVAL: + msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." + _LOGGER.critical(msg) + raise RLlibAgentError(msg) if not self._training_config.agent_framework == AgentFramework.RLLIB: msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) @@ -98,6 +103,7 @@ class RLlibAgent(AgentSessionABC): f"deep_learning_framework=" f"{self._training_config.deep_learning_framework}" ) + self._train_agent = None # Required to capture the learning agent to close after eval def _update_session_metadata_file(self) -> None: """ @@ -179,20 +185,73 @@ class RLlibAgent(AgentSessionABC): self._current_result = self._agent.train() self._save_checkpoint() self.save() - self._agent.stop() - super().learn() + # Done this way as the RLlib eval can only be performed if the session hasn't been stopped + if self._training_config.session_type is not SessionType.TRAIN: + self._train_agent = self._agent + else: + self._agent.stop() + self._plot_av_reward_per_episode(learning_session=True) + + def _unpack_saved_agent_into_eval(self) -> Path: + """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" + agent_restore_path = self.evaluation_path / "agent_restore" + if agent_restore_path.exists(): + shutil.rmtree(agent_restore_path) + agent_restore_path.mkdir() + with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file: + zip_file.extractall(agent_restore_path) + return agent_restore_path + + def _setup_eval(self): + self._can_learn = False + self._can_evaluate = True + self._agent.restore(str(self._unpack_saved_agent_into_eval())) def evaluate( self, - **kwargs: None, - ) -> None: + **kwargs, + ): """ Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ - raise NotImplementedError + time_steps = self._training_config.num_eval_steps + episodes = self._training_config.num_eval_episodes + + self._setup_eval() + + self._env: Primaite = Primaite( + self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str + ) + + self._env.set_as_eval() + self.is_eval = True + if self._training_config.deterministic: + deterministic_str = "deterministic" + else: + deterministic_str = "non-deterministic" + _LOGGER.info( + f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." + ) + for episode in range(episodes): + obs = self._env.reset() + for step in range(time_steps): + action = self._agent.compute_single_action(observation=obs, explore=False) + + obs, rewards, done, info = self._env.step(action) + + self._env.reset() + self._env.close() + super().evaluate() + # Now we're safe to close the learning agent and write the mean rewards per episode for it + if self._training_config.session_type is not SessionType.TRAIN: + self._train_agent.stop() + self._plot_av_reward_per_episode(learning_session=True) + # Perform a clean-up of the unpacked agent + if (self.evaluation_path / "agent_restore").exists(): + shutil.rmtree((self.evaluation_path / "agent_restore")) def _get_latest_checkpoint(self) -> None: raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 5a9f9482..783f57eb 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from __future__ import annotations import json @@ -153,6 +153,8 @@ class SB3Agent(AgentSessionABC): # save agent self.save() + self._plot_av_reward_per_episode(learning_session=True) + def evaluate( self, **kwargs: Any, diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index 18ffa72b..bfdff869 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import numpy as np diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index ff0ca8d2..08d46294 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from typing import Dict, List, Union import numpy as np diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 14db236c..c3e173af 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -1,18 +1,15 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Provides a CLI using Typer as an entry point.""" import logging import os -import shutil from enum import Enum -from pathlib import Path from typing import Optional -import pkg_resources import typer import yaml -from platformdirs import PlatformDirs from typing_extensions import Annotated +from primaite import PRIMAITE_PATHS from primaite.data_viz import PlotlyTemplate app = typer.Typer() @@ -47,10 +44,10 @@ def logs(last_n: Annotated[int, typer.Option("-n")]) -> None: """ import re - from primaite import LOG_PATH + from primaite import PRIMAITE_PATHS - if os.path.isfile(LOG_PATH): - with open(LOG_PATH) as file: + if os.path.isfile(PRIMAITE_PATHS.app_log_file_path): + with open(PRIMAITE_PATHS.app_log_file_path) as file: lines = file.readlines() for line in lines[-last_n:]: print(re.sub(r"\n*", "", line)) @@ -70,16 +67,13 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> For example, to set the to debug, call: primaite log-level DEBUG """ - app_dirs = PlatformDirs(appname="primaite") - app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) - user_config_path = app_dirs.user_config_path / "primaite_config.yaml" - if user_config_path.exists(): - with open(user_config_path, "r") as file: + if PRIMAITE_PATHS.app_config_file_path.exists(): + with open(PRIMAITE_PATHS.app_config_file_path, "r") as file: primaite_config = yaml.safe_load(file) if level: primaite_config["logging"]["log_level"] = level.value - with open(user_config_path, "w") as file: + with open(PRIMAITE_PATHS.app_config_file_path, "w") as file: yaml.dump(primaite_config, file) print(f"PrimAITE Log Level: {level}") else: @@ -118,16 +112,8 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ - # Does this way to avoid using PrimAITE package before config is loaded - app_dirs = PlatformDirs(appname="primaite") - app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) - user_config_path = app_dirs.user_config_path / "primaite_config.yaml" - pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) - - shutil.copy2(pkg_config_path, user_config_path) - from primaite import getLogger - from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs + from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs _LOGGER = getLogger(__name__) @@ -136,7 +122,7 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Building primaite_config.yaml...") _LOGGER.info("Building the PrimAITE app directories...") - setup_app_dirs.run() + PRIMAITE_PATHS.mkdirs() _LOGGER.info("Rebuilding the demo notebooks...") reset_demo_notebooks.run(overwrite_existing=True) @@ -157,11 +143,11 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[ tc: The training config filepath. Optional. If no value is passed then example default training config is used from: - ~/primaite/config/example_config/training/training_config_main.yaml. + ~/primaite/2.0.0rc2/config/example_config/training/training_config_main.yaml. ldc: The lay down config file path. Optional. If no value is passed then example default lay down config is used from: - ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. + ~/primaite/2.0.0rc2/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. load: The directory of a previous session. Optional. If no value is passed, then the session will use the default training config and laydown config. Inversely, if a training config and laydown config @@ -173,15 +159,18 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[ from primaite.main import run if load is not None: + # run a loaded session run(session_path=load) - if not tc: - tc = main_training_config_path() + else: + # start a new session using tc and ldc + if not tc: + tc = main_training_config_path() - if not ldc: - ldc = dos_very_basic_config_path() + if not ldc: + ldc = dos_very_basic_config_path() - run(training_config_path=tc, lay_down_config_path=ldc) + run(training_config_path=tc, lay_down_config_path=ldc) @app.command() @@ -195,16 +184,13 @@ def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK """ - app_dirs = PlatformDirs(appname="primaite") - app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) - user_config_path = app_dirs.user_config_path / "primaite_config.yaml" - if user_config_path.exists(): - with open(user_config_path, "r") as file: + if PRIMAITE_PATHS.app_config_file_path.exists(): + with open(PRIMAITE_PATHS.app_config_file_path, "r") as file: primaite_config = yaml.safe_load(file) if template: primaite_config["session"]["outputs"]["plots"]["template"] = template.value - with open(user_config_path, "w") as file: + with open(PRIMAITE_PATHS.app_config_file_path, "w") as file: yaml.dump(primaite_config, file) print(f"PrimAITE plotly template: {template.value}") else: diff --git a/src/primaite/common/__init__.py b/src/primaite/common/__init__.py index 738a30d1..5770bcbc 100644 --- a/src/primaite/common/__init__.py +++ b/src/primaite/common/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Objects which are shared between many PrimAITE modules.""" diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index d74ec795..006301f1 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Enumerations for APE.""" from enum import Enum, IntEnum diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py index 048ed0ab..6940ba3f 100644 --- a/src/primaite/common/protocol.py +++ b/src/primaite/common/protocol.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The protocol class.""" diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py index 7ee694db..956815e8 100644 --- a/src/primaite/common/service.py +++ b/src/primaite/common/service.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The Service class.""" from primaite.common.enums import SoftwareState diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py index 9bd899f7..92f5a7d2 100644 --- a/src/primaite/config/__init__.py +++ b/src/primaite/config/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Configuration parameters for running experiments.""" diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 9cadc509..65ca7e91 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,15 +1,15 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from logging import Logger from pathlib import Path from typing import Any, Dict, Final, Union import yaml -from primaite import getLogger, USERS_CONFIG_DIR +from primaite import getLogger, PRIMAITE_PATHS _LOGGER: Logger = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down" def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 56402bfb..ebfee09a 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from __future__ import annotations from dataclasses import dataclass, field @@ -8,7 +8,7 @@ from typing import Any, Dict, Final, Optional, Union import yaml -from primaite import getLogger, USERS_CONFIG_DIR +from primaite import getLogger, PRIMAITE_PATHS from primaite.common.enums import ( ActionType, AgentFramework, @@ -22,7 +22,7 @@ from primaite.common.enums import ( _LOGGER: Logger = getLogger(__name__) -_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" +_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training" def main_training_config_path() -> Path: @@ -246,6 +246,7 @@ class TrainingConfig: return data def __str__(self) -> str: + obs_str = ",".join([c["name"] for c in self.observation_space["components"]]) tc = f"{self.agent_framework}, " if self.agent_framework is AgentFramework.RLLIB: tc += f"{self.deep_learning_framework}, " @@ -253,7 +254,7 @@ class TrainingConfig: if self.agent_identifier is AgentIdentifier.HARDCODED: tc += f"{self.hard_coded_agent_view}, " tc += f"{self.action_type}, " - tc += f"observation_space={self.observation_space}, " + tc += f"observation_space={obs_str}, " if self.session_type is SessionType.TRAIN: tc += f"{self.num_train_episodes} episodes @ " tc += f"{self.num_train_steps} steps" diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py index ad43c141..260579da 100644 --- a/src/primaite/data_viz/__init__.py +++ b/src/primaite/data_viz/__init__.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Utility to generate plots of sessions metrics after PrimAITE.""" from enum import Enum diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py index 39c2b4cc..37750353 100644 --- a/src/primaite/data_viz/session_plots.py +++ b/src/primaite/data_viz/session_plots.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Dict, Optional, Union @@ -7,13 +7,12 @@ import polars as pl import yaml from plotly.graph_objs import Figure -from primaite import _PLATFORM_DIRS +from primaite import PRIMAITE_PATHS -def _get_plotly_config() -> Dict: +def get_plotly_config() -> Dict: """Get the plotly config from primaite_config.yaml.""" - user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" - with open(user_config_path, "r") as file: + with open(PRIMAITE_PATHS.app_config_file_path, "r") as file: primaite_config = yaml.safe_load(file) return primaite_config["session"]["outputs"]["plots"] @@ -41,7 +40,7 @@ def plot_av_reward_per_episode( if subtitle: title = subtitle - config = _get_plotly_config() + config = get_plotly_config() layout = go.Layout( autosize=config["size"]["auto_size"], width=config["size"]["width"], diff --git a/src/primaite/environment/__init__.py b/src/primaite/environment/__init__.py index e837fe1e..f0fd21b9 100644 --- a/src/primaite/environment/__init__.py +++ b/src/primaite/environment/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Gym/Gymnasium environment for RL agents consisting of a simulated computer network.""" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a0423b89..383a9b5a 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index bd9b3689..cde586ed 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy import logging @@ -261,10 +261,12 @@ class Primaite(Env): self, transaction_writer=True, learning_session=True ) + self.is_eval = False + @property def actual_episode_count(self) -> int: - """Shifts the episode_count by -1 for RLlib.""" - if self.training_config.agent_framework is AgentFramework.RLLIB: + """Shifts the episode_count by -1 for RLlib learning session.""" + if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: return self.episode_count - 1 return self.episode_count @@ -276,6 +278,7 @@ class Primaite(Env): self.step_count = 0 self.total_step_count = 0 self.episode_steps = self.training_config.num_eval_steps + self.is_eval = True def _write_av_reward_per_episode(self) -> None: if self.actual_episode_count > 0: diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 92ef89ec..aa9dc97d 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Implements reward function.""" from logging import Logger from typing import Dict, TYPE_CHECKING, Union diff --git a/src/primaite/exceptions.py b/src/primaite/exceptions.py new file mode 100644 index 00000000..3b4058ac --- /dev/null +++ b/src/primaite/exceptions.py @@ -0,0 +1,11 @@ +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +class PrimaiteError(Exception): + """The root PrimAITe Error.""" + + pass + + +class RLlibAgentError(PrimaiteError): + """Raised when there is a generic error with a RLlib agent that is specific to PRimAITE.""" + + pass diff --git a/src/primaite/links/__init__.py b/src/primaite/links/__init__.py index 21ce44ba..c91b6951 100644 --- a/src/primaite/links/__init__.py +++ b/src/primaite/links/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Network connections between nodes in the simulation.""" diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index aa3fa7fb..3830a15b 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The link class.""" from typing import List diff --git a/src/primaite/main.py b/src/primaite/main.py index aed39d73..03f4fb35 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The main PrimAITE session runner module.""" import argparse from pathlib import Path diff --git a/src/primaite/nodes/__init__.py b/src/primaite/nodes/__init__.py index 43b213d6..231b8d92 100644 --- a/src/primaite/nodes/__init__.py +++ b/src/primaite/nodes/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Nodes represent network hosts in the simulation.""" diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index b5df70b5..8f472e86 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """An Active Node (i.e. not an actuator).""" import logging from typing import Final diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 9118fa46..fc4d41d3 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The base Node class.""" from typing import Final diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 8e03b40f..6e35d0ec 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Defines node behaviour for Green PoL.""" from typing import TYPE_CHECKING, Union diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 786e93ac..eb87924b 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,6 +1,5 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Defines node behaviour for Green PoL.""" -from dataclasses import dataclass from typing import TYPE_CHECKING, Union from primaite.common.enums import NodePOLType @@ -9,8 +8,7 @@ if TYPE_CHECKING: from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState -@dataclass() -class NodeStateInstructionRed(object): +class NodeStateInstructionRed: """The Node State Instruction class.""" def __init__( diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index 88c8cc85..08dcbfa2 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The Passive Node class (i.e. an actuator).""" from primaite.common.enums import HardwareState, NodeType, Priority from primaite.config.training_config import TrainingConfig diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index ce1ffe92..b0d42785 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """A Service Node (i.e. not an actuator).""" import logging from typing import Dict, Final diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 390fddb4..bc1dcfcd 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Contains default jupyter notebooks which demonstrate PrimAITE functionality.""" import importlib.util @@ -7,7 +7,7 @@ import subprocess import sys from logging import Logger -from primaite import getLogger, NOTEBOOKS_DIR +from primaite import getLogger, PRIMAITE_PATHS _LOGGER: Logger = getLogger(__name__) @@ -26,7 +26,7 @@ def start_jupyter_session() -> None: jupyter_cmd = "jupyter lab" working_dir = os.getcwd() - os.chdir(NOTEBOOKS_DIR) + os.chdir(PRIMAITE_PATHS.user_notebooks_path) subprocess.Popen(jupyter_cmd) os.chdir(working_dir) else: diff --git a/src/primaite/pol/__init__.py b/src/primaite/pol/__init__.py index 1adb1491..d0d9f616 100644 --- a/src/primaite/pol/__init__.py +++ b/src/primaite/pol/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Pattern of Life- Represents the actions of users on the network.""" diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 0425a831..814aa314 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Implements Pattern of Life on the network (nodes and links).""" from typing import Dict diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py index 7fab340d..b8da28c0 100644 --- a/src/primaite/pol/ier.py +++ b/src/primaite/pol/ier.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """ Information Exchange Requirements for APE. diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index ad55fa24..ca1a58da 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Implements POL on the network (nodes and links) resulting from the red agent attack.""" from typing import Dict @@ -250,6 +250,11 @@ def apply_red_agent_node_pol( # continue -------------------------- target_node: NodeUnion = nodes[target_node_id] + # check if the initiator type is a str, and if so, cast it as + # NodePOLInitiator + if isinstance(initiator, str): + initiator = NodePOLInitiator[initiator] + # Based the action taken on the initiator type if initiator == NodePOLInitiator.DIRECT: # No conditions required, just apply the change diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index ab3c2150..c64b51fb 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -1,9 +1,10 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Main entry point to PrimAITE. Configure training/evaluation experiments and input/output.""" from __future__ import annotations +import json from pathlib import Path -from typing import Any, Dict, Final, Optional, Union +from typing import Any, Dict, Final, Optional, Tuple, Union from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC @@ -16,6 +17,7 @@ from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, S from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.utils.session_metadata_parser import parse_session_metadata +from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict _LOGGER = getLogger(__name__) @@ -70,13 +72,7 @@ class PrimaiteSession: if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - - self._agent_session: AgentSessionABC = None # noqa - self.session_path: Path = None # noqa - self.timestamp_str: str = None # noqa - self.learning_path: Path = None # noqa - self.evaluation_path: Path = None # noqa + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) # noqa def setup(self) -> None: """Performs the session setup.""" @@ -186,3 +182,28 @@ class PrimaiteSession: def close(self) -> None: """Closes the agent.""" self._agent_session.close() + + def learn_av_reward_per_episode_dict(self) -> Dict[int, float]: + """Get the learn av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.learning_path / csv_file) + + def eval_av_reward_per_episode_dict(self) -> Dict[int, float]: + """Get the eval av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.evaluation_path / csv_file) + + def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the learn all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.learning_path / csv_file) + + def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the eval all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.evaluation_path / csv_file) + + def metadata_file_as_dict(self) -> Dict[str, Any]: + """Read the session_metadata.json file and return as a dict.""" + with open(self.session_path / "session_metadata.json", "r") as file: + return json.load(file) diff --git a/src/primaite/setup/__init__.py b/src/primaite/setup/__init__.py index acfa48c4..12b77f1e 100644 --- a/src/primaite/setup/__init__.py +++ b/src/primaite/setup/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Utilities to prepare the user's data folders.""" diff --git a/src/primaite/setup/old_installation_clean_up.py b/src/primaite/setup/old_installation_clean_up.py index d23abf3c..412aed60 100644 --- a/src/primaite/setup/old_installation_clean_up.py +++ b/src/primaite/setup/old_installation_clean_up.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from primaite import getLogger diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index f47af1dc..1f96c90f 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import filecmp import os import shutil @@ -7,7 +7,7 @@ from pathlib import Path import pkg_resources -from primaite import getLogger, NOTEBOOKS_DIR +from primaite import getLogger, PRIMAITE_PATHS _LOGGER: Logger = getLogger(__name__) @@ -23,7 +23,7 @@ def run(overwrite_existing: bool = True) -> None: for file in files: fp = os.path.join(subdir, file) path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep) - target_fp = NOTEBOOKS_DIR / Path(*path_split) + target_fp = PRIMAITE_PATHS.user_notebooks_path / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index 68ce588c..41345853 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import filecmp import os import shutil @@ -6,7 +6,7 @@ from pathlib import Path import pkg_resources -from primaite import getLogger, USERS_CONFIG_DIR +from primaite import getLogger, PRIMAITE_PATHS _LOGGER = getLogger(__name__) @@ -23,7 +23,7 @@ def run(overwrite_existing: bool = True) -> None: for file in files: fp = os.path.join(subdir, file) path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep) - target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split) + target_fp = PRIMAITE_PATHS.user_config_path / "example_config" / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py deleted file mode 100644 index 68b5d772..00000000 --- a/src/primaite/setup/setup_app_dirs.py +++ /dev/null @@ -1,29 +0,0 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. -from logging import Logger - -from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR - -_LOGGER: Logger = getLogger(__name__) - - -def run() -> None: - """ - Handles creation of application directories and user directories. - - Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required - app directories in the correct locations based on the users OS. - """ - app_dirs = [ - _USER_DIRS, - NOTEBOOKS_DIR, - LOG_DIR, - ] - - for app_dir in app_dirs: - if not app_dir.is_dir(): - app_dir.mkdir(parents=True, exist_ok=True) - _LOGGER.info(f"Created directory: {app_dir}") - - -if __name__ == "__main__": - run() diff --git a/src/primaite/transactions/__init__.py b/src/primaite/transactions/__init__.py index 9a881fd5..505c5080 100644 --- a/src/primaite/transactions/__init__.py +++ b/src/primaite/transactions/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Record data of the system's state and agent's observations and actions.""" diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 1a702748..7d5f747c 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """The Transaction class.""" from datetime import datetime from typing import List, Optional, Tuple, TYPE_CHECKING, Union diff --git a/src/primaite/utils/__init__.py b/src/primaite/utils/__init__.py index 5dbd1e5f..4f9deb13 100644 --- a/src/primaite/utils/__init__.py +++ b/src/primaite/utils/__init__.py @@ -1,2 +1,2 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Utilities for PrimAITE.""" diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py index 96157b40..ac41e8bc 100644 --- a/src/primaite/utils/package_data.py +++ b/src/primaite/utils/package_data.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import os from logging import Logger from pathlib import Path diff --git a/src/primaite/utils/session_metadata_parser.py b/src/primaite/utils/session_metadata_parser.py index 0b0eaaec..2548a8b6 100644 --- a/src/primaite/utils/session_metadata_parser.py +++ b/src/primaite/utils/session_metadata_parser.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import json from pathlib import Path from typing import Any, Dict, Union diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py index 7089c69a..30febff1 100644 --- a/src/primaite/utils/session_output_reader.py +++ b/src/primaite/utils/session_output_reader.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Any, Dict, Tuple, Union @@ -18,7 +18,7 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: """ df_dict = pl.read_csv(av_rewards_csv_file).to_dict() - return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])} + return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])} def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]: diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index e7f1b248..0eb18038 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import csv from logging import Logger from typing import Final, List, Tuple, TYPE_CHECKING, Union diff --git a/tests/__init__.py b/tests/__init__.py index f8e6fc55..5a06b646 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from pathlib import Path from typing import Final diff --git a/tests/config/legacy_conversion/legacy_training_config.yaml b/tests/config/legacy_conversion/legacy_training_config.yaml index fb24e3d7..3477e6e0 100644 --- a/tests/config/legacy_conversion/legacy_training_config.yaml +++ b/tests/config/legacy_conversion/legacy_training_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Main Config File # Generic config values diff --git a/tests/config/legacy_conversion/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml index 3df29d04..1ec36e97 100644 --- a/tests/config/legacy_conversion/new_training_config.yaml +++ b/tests/config/legacy_conversion/new_training_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Main Config File # Generic config values diff --git a/tests/config/obs_tests/laydown.yaml b/tests/config/obs_tests/laydown.yaml index 4ab44755..e358d0d2 100644 --- a/tests/config/obs_tests/laydown.yaml +++ b/tests/config/obs_tests/laydown.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - item_type: PORTS ports_list: - port: '80' diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml index 689d6bb4..805ab31e 100644 --- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml index 885f7e79..535558aa 100644 --- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml index c662e715..d1319c35 100644 --- a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml index a2af9763..26457c84 100644 --- a/tests/config/obs_tests/main_config_without_obs.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index 65257d62..0f572d8d 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - item_type: PORTS ports_list: - port: '21' diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index dbe4256f..10af7a1f 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml index 2160a3a3..fac2fe95 100644 --- a/tests/config/ppo_not_seeded_training_config.yaml +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/ppo_seeded_training_config.yaml b/tests/config/ppo_seeded_training_config.yaml index 7512dc85..e4d4fe5b 100644 --- a/tests/config/ppo_seeded_training_config.yaml +++ b/tests/config/ppo_seeded_training_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/training_config_main_rllib.yaml b/tests/config/session_test/training_config_main_rllib.yaml similarity index 98% rename from tests/config/training_config_main_rllib.yaml rename to tests/config/session_test/training_config_main_rllib.yaml index 40cbc0fc..374c6ac5 100644 --- a/tests/config/training_config_main_rllib.yaml +++ b/tests/config/session_test/training_config_main_rllib.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. @@ -69,7 +69,7 @@ num_train_episodes: 10 num_train_steps: 256 # Number of episodes for evaluation to run per session -num_eval_episodes: 1 +num_eval_episodes: 3 # Number of time_steps for evaluation per episode num_eval_steps: 256 diff --git a/tests/config/session_test/training_config_main_sb3.yaml b/tests/config/session_test/training_config_main_sb3.yaml new file mode 100644 index 00000000..733105ea --- /dev/null +++ b/tests/config/session_test/training_config_main_sb3.yaml @@ -0,0 +1,164 @@ +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 3 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 644d5912..6210cf3e 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index 866eebe8..9103e2b7 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - item_type: PORTS ports_list: - port: '80' diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index deaad090..67eaf49d 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml index 3416029c..310c9dc6 100644 --- a/tests/config/test_random_red_main_config.yaml +++ b/tests/config/test_random_red_main_config.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/config/train_episode_step.yaml b/tests/config/train_episode_step.yaml index 31337b0c..a86e0f62 100644 --- a/tests/config/train_episode_step.yaml +++ b/tests/config/train_episode_step.yaml @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # Training Config File # Sets which agent algorithm framework will be used. diff --git a/tests/conftest.py b/tests/conftest.py index 9b0db139..f40b0b94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,10 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import datetime -import json import shutil import tempfile from datetime import datetime from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Union from unittest.mock import patch import pytest @@ -13,7 +12,6 @@ import pytest from primaite import getLogger from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession -from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -37,31 +35,6 @@ class TempPrimaiteSession(PrimaiteSession): super().__init__(training_config_path, lay_down_config_path) self.setup() - def learn_av_reward_per_episode_dict(self) -> Dict[int, float]: - """Get the learn av reward per episode from file.""" - csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" - return av_rewards_dict(self.learning_path / csv_file) - - def eval_av_reward_per_episode_dict(self) -> Dict[int, float]: - """Get the eval av reward per episode from file.""" - csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" - return av_rewards_dict(self.evaluation_path / csv_file) - - def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: - """Get the learn all transactions from file.""" - csv_file = f"all_transactions_{self.timestamp_str}.csv" - return all_transactions_dict(self.learning_path / csv_file) - - def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: - """Get the eval all transactions from file.""" - csv_file = f"all_transactions_{self.timestamp_str}.csv" - return all_transactions_dict(self.evaluation_path / csv_file) - - def metadata_file_as_dict(self) -> Dict[str, Any]: - """Read the session_metadata.json file and return as a dict.""" - with open(self.session_path / "session_metadata.json", "r") as file: - return json.load(file) - @property def env(self) -> Primaite: """Direct access to the env for ease of testing.""" diff --git a/tests/mock_and_patch/__init__.py b/tests/mock_and_patch/__init__.py index 778748f7..e0315ce3 100644 --- a/tests/mock_and_patch/__init__.py +++ b/tests/mock_and_patch/__init__.py @@ -1 +1 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index 190e1dba..16c4a274 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import tempfile from datetime import datetime from pathlib import Path diff --git a/tests/test_acl.py b/tests/test_acl.py index 3491aab8..d8357cf6 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Used to tes the ACL functions.""" from primaite.acl.access_control_list import AccessControlList diff --git a/tests/test_active_node.py b/tests/test_active_node.py index 880c0f02..44d38313 100644 --- a/tests/test_active_node.py +++ b/tests/test_active_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Used to test Active Node functions.""" import pytest diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index c4a9789c..ff3528e1 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Test env creation and behaviour with different observation spaces.""" import numpy as np diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index 210d931e..b76a2ecf 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -1,22 +1,29 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import os import pytest from primaite import getLogger from primaite.config.lay_down_config import dos_very_basic_config_path -from primaite.config.training_config import main_training_config_path +from tests import TEST_CONFIG_ROOT _LOGGER = getLogger(__name__) @pytest.mark.parametrize( "temp_primaite_session", - [[main_training_config_path(), dos_very_basic_config_path()]], + [ + [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()], + [TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()], + ], indirect=True, ) def test_primaite_session(temp_primaite_session): - """Tests the PrimaiteSession class and its outputs.""" + """ + Tests the PrimaiteSession class and all of its outputs. + + This test runs for both a Stable Baselines3 agent, and a Ray RLlib agent. + """ with temp_primaite_session as session: session_path = session.session_path assert session_path.exists() @@ -47,6 +54,17 @@ def test_primaite_session(temp_primaite_session): if file.suffix == ".csv": assert "all_transactions" in file.name or "average_reward_per_episode" in file.name + # Check that the average reward per episode plots exist + assert (session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists() + assert (session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists() + + # Check that the metadata has captured the correct number of learning and eval episodes and steps + assert len(session.learn_av_reward_per_episode_dict().keys()) == 10 + assert len(session.learn_all_transactions_dict().keys()) == 10 * 256 + + assert len(session.eval_av_reward_per_episode_dict().keys()) == 3 + assert len(session.eval_all_transactions_dict().keys()) == 3 * 256 + _LOGGER.debug("Inspecting files in temp session path...") for dir_path, dir_names, file_names in os.walk(session_path): for file in file_names: diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 3496ed9d..e99f4adb 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import pytest from primaite.config.lay_down_config import data_manipulation_config_path diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index 80e13c5b..d4e27c17 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Used to test Active Node functions.""" import pytest diff --git a/tests/test_reward.py b/tests/test_reward.py index 741c6f13..2ac66af1 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import pytest from primaite import getLogger diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py deleted file mode 100644 index f494ea81..00000000 --- a/tests/test_rllib_agent.py +++ /dev/null @@ -1,24 +0,0 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. -import pytest - -from primaite import getLogger -from primaite.config.lay_down_config import dos_very_basic_config_path -from tests import TEST_CONFIG_ROOT - -_LOGGER = getLogger(__name__) - - -@pytest.mark.parametrize( - "temp_primaite_session", - [[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]], - indirect=True, -) -def test_primaite_session(temp_primaite_session): - """Test the training_config_main_rllib.yaml training config file.""" - with temp_primaite_session as session: - session_path = session.session_path - assert session_path.exists() - session.learn() - - assert len(session.learn_av_reward_per_episode_dict().keys()) == 10 - assert len(session.learn_all_transactions_dict().keys()) == 10 * 256 diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index c4b47d5f..9500c4a3 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import pytest as pytest from primaite.config.lay_down_config import dos_very_basic_config_path @@ -50,7 +50,6 @@ def test_seeded_learning(temp_primaite_session): assert actual_mean_reward_per_episode == expected_mean_reward_per_episode -@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL knowledge to investigate further.") @pytest.mark.parametrize( "temp_primaite_session", [[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]], diff --git a/tests/test_service_node.py b/tests/test_service_node.py index 2f504cd6..906bcf55 100644 --- a/tests/test_service_node.py +++ b/tests/test_service_node.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK """Used to test Service Node functions.""" import pytest diff --git a/tests/test_session_loading.py b/tests/test_session_loading.py index c624e200..f9990f76 100644 --- a/tests/test_session_loading.py +++ b/tests/test_session_loading.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import os.path import shutil import tempfile @@ -6,10 +6,11 @@ from pathlib import Path from typing import Union from uuid import uuid4 -import pytest +from typer.testing import CliRunner from primaite import getLogger from primaite.agents.sb3 import SB3Agent +from primaite.cli import app from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.main import run from primaite.primaite_session import PrimaiteSession @@ -18,6 +19,24 @@ from tests import TEST_ASSETS_ROOT _LOGGER = getLogger(__name__) +runner = CliRunner() + +sb3_expected_avg_reward_per_episode = { + 10: 0.0, + 11: -0.0011074218750000008, + 12: -0.0010000000000000007, + 13: -0.0016601562500000013, + 14: -0.001400390625000001, + 15: -0.0009863281250000007, + 16: -0.0011855468750000008, + 17: -0.0009511718750000007, + 18: -0.0008789062500000007, + 19: -0.0012226562500000009, + 20: -0.0010292968750000007, +} + +sb3_expected_eval_rewards = -0.0018515625000000014 + def copy_session_asset(asset_path: Union[str, Path]) -> str: """Copies the asset into a temporary test folder.""" @@ -43,25 +62,8 @@ def copy_session_asset(asset_path: Union[str, Path]) -> str: return copy_path -@pytest.mark.xfail( - reason="Loading works fine but the exact values change with code changes, a bug report has been created." -) def test_load_sb3_session(): """Test that loading an SB3 agent works.""" - expected_learn_mean_reward_per_episode = { - 10: 0, - 11: -0.008037109374999995, - 12: -0.007978515624999988, - 13: -0.008191406249999991, - 14: -0.00817578124999999, - 15: -0.008085937499999998, - 16: -0.007837890624999982, - 17: -0.007798828124999992, - 18: -0.007777343749999998, - 19: -0.007958984374999988, - 20: -0.0077499999999999835, - } - test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") loaded_agent = SB3Agent(session_path=test_path) @@ -82,7 +84,7 @@ def test_load_sb3_session(): ) # run is seeded so should have the expected learn value - assert learn_mean_rewards == expected_learn_mean_reward_per_episode + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode # run an evaluation loaded_agent.evaluate() @@ -96,29 +98,14 @@ def test_load_sb3_session(): assert len(set(eval_mean_reward.values())) == 1 # the evaluation should be the same as a previous run - assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards # delete the test directory shutil.rmtree(test_path) -@pytest.mark.xfail(reason="Temporarily don't worry about this not working") def test_load_primaite_session(): """Test that loading a Primaite session works.""" - expected_learn_mean_reward_per_episode = { - 10: 0, - 11: -0.008037109374999995, - 12: -0.007978515624999988, - 13: -0.008191406249999991, - 14: -0.00817578124999999, - 15: -0.008085937499999998, - 16: -0.007837890624999982, - 17: -0.007798828124999992, - 18: -0.007777343749999998, - 19: -0.007958984374999988, - 20: -0.0077499999999999835, - } - test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") # create loaded session @@ -143,7 +130,7 @@ def test_load_primaite_session(): ) # run is seeded so should have the expected learn value - assert learn_mean_rewards == expected_learn_mean_reward_per_episode + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode # run an evaluation session.evaluate() @@ -157,29 +144,14 @@ def test_load_primaite_session(): assert len(set(eval_mean_reward.values())) == 1 # the evaluation should be the same as a previous run - assert next(iter(set(eval_mean_reward.values()))) == -0.009896484374999988 + assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards # delete the test directory shutil.rmtree(test_path) -@pytest.mark.xfail(reason="Temporarily don't worry about this not working") def test_run_loading(): """Test loading session via main.run.""" - expected_learn_mean_reward_per_episode = { - 10: 0, - 11: -0.008037109374999995, - 12: -0.007978515624999988, - 13: -0.008191406249999991, - 14: -0.00817578124999999, - 15: -0.008085937499999998, - 16: -0.007837890624999982, - 17: -0.007798828124999992, - 18: -0.007777343749999998, - 19: -0.007958984374999988, - 20: -0.0077499999999999835, - } - test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") # create loaded session @@ -190,7 +162,26 @@ def test_run_loading(): ) # run is seeded so should have the expected learn value - assert learn_mean_rewards == expected_learn_mean_reward_per_episode + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode + + # delete the test directory + shutil.rmtree(test_path) + + +def test_cli(): + """Test loading session via CLI.""" + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + result = runner.invoke(app, ["session", "--load", test_path]) + + # cli should work + assert result.exit_code == 0 + + learn_mean_rewards = av_rewards_dict( + next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None) + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == sb3_expected_avg_reward_per_episode # delete the test directory shutil.rmtree(test_path) diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index b91bc2bf..5d300232 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import time import pytest diff --git a/tests/test_train_eval_episode_steps.py b/tests/test_train_eval_episode_steps.py index 4f7bed16..1b53fe9d 100644 --- a/tests/test_train_eval_episode_steps.py +++ b/tests/test_train_eval_episode_steps.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import pytest from primaite import getLogger diff --git a/tests/test_training_config.py b/tests/test_training_config.py index 4123ee39..58f9c797 100644 --- a/tests/test_training_config.py +++ b/tests/test_training_config.py @@ -1,4 +1,4 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK import yaml from primaite.config import training_config