Merge remote-tracking branch 'origin/dev' into release/2.0.0rc1

# Conflicts:
#	.gitignore
#	benchmark/results/PrimAITE Versions Learning Benchmark.png
#	benchmark/results/v2.0.0rc1/PrimAITE v2.0.0rc1 Learning Benchmark.pdf
#	benchmark/results/v2.0.0rc1/PrimAITE v2.0.0rc1 Learning Benchmark.png
#	benchmark/results/v2.0.0rc1/v2.0.0rc1_benchmark_metadata.json
Author: Chris McCarthy
Date: 2023-07-24 22:50:34 +01:00
111 changed files with 1305 additions and 440 deletions

CHANGELOG.md — new file (+90 lines)

@@ -0,0 +1,90 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [2.0.0] - 2023-07-31
### Added
- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE.
- Application Directories to enable PrimAITE as a Python package with predefined directories for storage.
- Support for Ray RLlib, allowing training of PPO and A2C agents with either Stable Baselines3 or Ray RLlib.
- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`.
- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options.
- Session loading to revisit previously run sessions for SB3 Agents.
- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface.
- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs:
1. Session Metadata
2. Results
3. Diagrams
4. Saved agents (training checkpoints and a final trained agent).
- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup.
- Benchmarking of PrimAITE performance, showcasing session and step durations for reference.
- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`.
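The configurable observation space is built from pluggable components registered with a handler. A minimal sketch of that pattern (the component internals here are hypothetical; the real `ObservationsHandler` also manages a Gym space):

```python
from abc import ABC, abstractmethod
from typing import List


class AbstractObservationComponent(ABC):
    """One independently updatable slice of the observation space."""

    current_observation: List[int]

    @abstractmethod
    def update(self) -> None:
        """Refresh current_observation from the environment state."""


class NodeStatuses(AbstractObservationComponent):
    """Hypothetical component exposing each node's hardware state as an int."""

    def __init__(self, nodes):
        self.nodes = nodes
        self.current_observation: List[int] = []

    def update(self) -> None:
        self.current_observation = [n["hardware_state"] for n in self.nodes]


class ObservationsHandler:
    """Aggregates registered components into one flat observation."""

    def __init__(self):
        self.registered_obs_components: List[AbstractObservationComponent] = []

    def register(self, component: AbstractObservationComponent) -> None:
        self.registered_obs_components.append(component)

    def update_obs(self) -> List[int]:
        obs: List[int] = []
        for component in self.registered_obs_components:
            component.update()
            obs.extend(component.current_observation)
        return obs
```

Components can then be registered or deregistered at setup time to reshape the observation space without touching the environment loop.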
### Changed
- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions.
- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`.
- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names.
- Docs and Tests now sit outside the `src/` directory.
- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages.
- All dependencies are now defined in the `pyproject.toml` file.
- Introduced separate config values for the number of episodes and time steps in training sessions and in evaluation sessions.
- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management.
- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations.
- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority.
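As a rough illustration of why the list matters: rule position now encodes priority, with earlier entries taking precedence. This is a simplified sketch with hypothetical field handling, not the actual `AccessControlList` implementation (which also hashes rules and supports an implicit rule):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ACLRule:
    permission: str  # "ALLOW" or "DENY"
    source_ip: str
    dest_ip: str
    protocol: str
    port: str


class AccessControlList:
    """Rules held in a fixed-size list: lower index = higher priority."""

    def __init__(self, implicit_permission: str = "DENY", max_acl_rules: int = 10):
        self.acl: List[Optional[ACLRule]] = [None] * max_acl_rules
        self.acl_implicit_permission = implicit_permission

    def add_rule(self, rule: ACLRule, position: int) -> None:
        self.acl[position] = rule

    def is_blocked(self, source_ip, dest_ip, protocol, port) -> bool:
        # Scan in priority order; the first matching rule decides.
        for rule in self.acl:
            if rule is None:
                continue
            if (rule.source_ip, rule.dest_ip, rule.protocol, rule.port) == (
                source_ip, dest_ip, protocol, port
            ):
                return rule.permission == "DENY"
        # No explicit match: fall back to the implicit permission.
        return self.acl_implicit_permission == "DENY"
```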
### Fixed
- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation.
- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes.
- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands.
- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making.
## [1.1.1] - 2023-06-27
### Bug Fixes
* Fixed a bug whereby 'reference' environment links reached bandwidth capacity and were never cleared, because green & red IERs were being applied to them. This had a knock-on effect: IERs were being blocked based on the full capacity of links in the reference environment, when blocking should only be based on the link capacity of the 'live' environment. This fix has been addressed by:
* Implementing a reference copy of all green IERs (`self.green_iers_reference`).
* Clearing the traffic on reference IERs at the same time as the live IERs.
* Passing the `green_iers_reference` to the `apply_iers` function at the reference stage.
* Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`.
* Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment.
* Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes.
* Removing the unnecessary "Reapply PoL and IERs" action from the step function.
* Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function.
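The corrected green-IER reward logic above can be sketched as follows (hypothetical names; the real `calculate_reward_function` takes additional arguments and covers other reward terms):

```python
from dataclasses import dataclass


@dataclass
class IER:
    running: bool


def green_ier_penalty(green_iers, green_iers_reference, green_ier_blocked=-1.0):
    """Penalise a green IER only if it is blocked in the live environment
    while still running in the reference (attack-free) environment."""
    reward = 0.0
    for ier_id, live_ier in green_iers.items():
        reference_ier = green_iers_reference[ier_id]
        if reference_ier.running and not live_ier.running:
            reward += green_ier_blocked
    return reward
```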
## [1.1.0] - 2023-03-13
### Added
* The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function.
* The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name.
* Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components:
* Blue agent observation space;
* Blue agent action space;
* Reward function;
* Node pattern-of-life.
* The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network).
* New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile.
* NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability.
* The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value.
* The User Guide has been updated to reflect all the above changes.
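The initiator mechanism described above amounts to a dispatch on the instruction's initiator type. A simplified sketch (the real `NodeStateInstructionRed` carries more fields, such as target node and start/end steps):

```python
from dataclasses import dataclass
from enum import Enum


class NodePOLInitiator(Enum):
    DIRECT = "DIRECT"
    IER = "IER"
    SERVICE = "SERVICE"


@dataclass
class RedInstruction:
    initiator: NodePOLInitiator
    source_node_service_state: str = ""


def should_apply_state_change(instruction, ier_is_running, source_service_state):
    """Decide whether a red node pattern-of-life state change fires this step."""
    if instruction.initiator == NodePOLInitiator.DIRECT:
        return True  # unconditional state change
    if instruction.initiator == NodePOLInitiator.IER:
        return ier_is_running  # gated on an IER entry condition
    if instruction.initiator == NodePOLInitiator.SERVICE:
        # gated on a service state condition on the same or another node
        return source_service_state == instruction.source_node_service_state
    return False
```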
### Changed
* "config_1_DDOS_BASIC.yaml" modified to make it simpler, to aid evaluation testing.
* "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
* "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
* "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file.
* Updates to Transactions.
### Fixed
* Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default.
[unreleased]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/compare/v2.0.0...HEAD
[2.0.0]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/tag/v2.0.0

LICENSE — new file (+21 lines)

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 - 2025 Defence Science and Technology Laboratory UK (https://dstl.gov.uk)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

diagram/classes.puml — new file (+521 lines)

@@ -0,0 +1,521 @@
@startuml classes
set namespaceSeparator none
class "ACLRule" as primaite.acl.acl_rule.ACLRule {
dest_ip : str
permission
port : str
protocol : str
source_ip : str
get_dest_ip() -> str
get_permission() -> str
get_port() -> str
get_protocol() -> str
get_source_ip() -> str
}
class "AbstractObservationComponent" as primaite.environment.observations.AbstractObservationComponent {
current_observation : NotImplementedType, ndarray
env : str
space : Space
structure : List[str]
{abstract}generate_structure() -> List[str]
{abstract}update() -> None
}
class "AccessControlList" as primaite.acl.access_control_list.AccessControlList {
acl
acl_implicit_permission
acl_implicit_rule
max_acl_rules : int
add_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: str) -> None
check_address_match(_rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool
get_dictionary_hash(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int
get_relevant_rules(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> Dict[int, ACLRule]
is_blocked(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool
remove_all_rules() -> None
remove_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None
}
class "AccessControlList_" as primaite.environment.observations.AccessControlList_ {
current_observation : ndarray
space : MultiDiscrete
structure : list
generate_structure() -> List[str]
update() -> None
}
class "ActiveNode" as primaite.nodes.active_node.ActiveNode {
file_system_action_count : int
file_system_scanning : bool
file_system_scanning_count : int
file_system_state_actual : GOOD
file_system_state_observed : REPAIRING, RESTORING, GOOD
ip_address : str
patching_count : int
software_state
software_state : GOOD
set_file_system_state(file_system_state: FileSystemState) -> None
set_file_system_state_if_not_compromised(file_system_state: FileSystemState) -> None
set_software_state_if_not_compromised(software_state: SoftwareState) -> None
start_file_system_scan() -> None
update_booting_status() -> None
update_file_system_state() -> None
update_os_patching_status() -> None
update_resetting_status() -> None
}
class "AgentSessionABC" as primaite.agents.agent_abc.AgentSessionABC {
checkpoints_path
evaluation_path
is_eval : bool
learning_path
sb3_output_verbose_level : NONE
session_path : Union[str, Path]
session_timestamp : datetime
timestamp_str
uuid
close() -> None
{abstract}evaluate() -> None
{abstract}export() -> None
{abstract}learn() -> None
load(path: Union[str, Path]) -> None
{abstract}save() -> None
}
class "DoNothingACLAgent" as primaite.agents.simple.DoNothingACLAgent {
}
class "DoNothingNodeAgent" as primaite.agents.simple.DoNothingNodeAgent {
}
class "DummyAgent" as primaite.agents.simple.DummyAgent {
}
class "HardCodedACLAgent" as primaite.agents.hardcoded_acl.HardCodedACLAgent {
get_allow_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule]
get_allow_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
get_blocked_green_iers(green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[str, IER]
get_blocking_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
get_deny_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule]
get_matching_acl_rules(source_node_id: str, dest_node_id: str, protocol: str, port: str, acl: AccessControlList, nodes: Dict[str, Union[ServiceNode, ActiveNode]], services_list: List[str]) -> Dict[int, ACLRule]
get_matching_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
}
class "HardCodedAgentSessionABC" as primaite.agents.hardcoded_abc.HardCodedAgentSessionABC {
is_eval : bool
evaluate() -> None
export() -> None
learn() -> None
load(path: Union[str, Path]) -> None
save() -> None
}
class "HardCodedNodeAgent" as primaite.agents.hardcoded_node.HardCodedNodeAgent {
}
class "IER" as primaite.pol.ier.IER {
dest_node_id : str
end_step : int
id : str
load : int
mission_criticality : int
port : str
protocol : str
running : bool
source_node_id : str
start_step : int
get_dest_node_id() -> str
get_end_step() -> int
get_id() -> str
get_is_running() -> bool
get_load() -> int
get_mission_criticality() -> int
get_port() -> str
get_protocol() -> str
get_source_node_id() -> str
get_start_step() -> int
set_is_running(_value: bool) -> None
}
class "Link" as primaite.links.link.Link {
bandwidth : int
dest_node_name : str
id : str
protocol_list : List[Protocol]
source_node_name : str
add_protocol(_protocol: str) -> None
add_protocol_load(_protocol: str, _load: int) -> None
clear_traffic() -> None
get_bandwidth() -> int
get_current_load() -> int
get_dest_node_name() -> str
get_id() -> str
get_protocol_list() -> List[Protocol]
get_source_node_name() -> str
}
class "LinkTrafficLevels" as primaite.environment.observations.LinkTrafficLevels {
current_observation : ndarray
space : MultiDiscrete
structure : list
generate_structure() -> List[str]
update() -> None
}
class "Node" as primaite.nodes.node.Node {
booting_count : int
config_values
hardware_state : BOOTING, ON, RESETTING, OFF
name : Final[str]
node_id : Final[str]
node_type : Final[NodeType]
priority
resetting_count : int
shutting_down_count : int
reset() -> None
turn_off() -> None
turn_on() -> None
update_booting_status() -> None
update_resetting_status() -> None
update_shutdown_status() -> None
}
class "NodeLinkTable" as primaite.environment.observations.NodeLinkTable {
current_observation : ndarray
space : Box
structure : list
generate_structure() -> List[str]
update() -> None
}
class "NodeStateInstructionGreen" as primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen {
end_step : int
id : str
node_id : str
node_pol_type : str
service_name : str
start_step : int
state : Union['HardwareState', 'SoftwareState', 'FileSystemState']
get_end_step() -> int
get_node_id() -> str
get_node_pol_type() -> 'NodePOLType'
get_service_name() -> str
get_start_step() -> int
get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState']
}
class "NodeStateInstructionRed" as primaite.nodes.node_state_instruction_red.NodeStateInstructionRed {
end_step : int
id : str
initiator : str
pol_type
service_name : str
source_node_id : str
source_node_service : str
source_node_service_state : str
start_step : int
state : Union['HardwareState', 'SoftwareState', 'FileSystemState']
target_node_id : str
get_end_step() -> int
get_initiator() -> 'NodePOLInitiator'
get_pol_type() -> NodePOLType
get_service_name() -> str
get_source_node_id() -> str
get_source_node_service() -> str
get_source_node_service_state() -> str
get_start_step() -> int
get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState']
get_target_node_id() -> str
}
class "NodeStatuses" as primaite.environment.observations.NodeStatuses {
current_observation : ndarray
space : MultiDiscrete
structure : list
generate_structure() -> List[str]
update() -> None
}
class "ObservationsHandler" as primaite.environment.observations.ObservationsHandler {
current_observation
registered_obs_components : List[AbstractObservationComponent]
space
deregister(obs_component: AbstractObservationComponent) -> None
describe_structure() -> List[str]
from_config(env: 'Primaite', obs_space_config: dict) -> 'ObservationsHandler'
register(obs_component: AbstractObservationComponent) -> None
update_obs() -> None
update_space() -> None
}
class "PassiveNode" as primaite.nodes.passive_node.PassiveNode {
ip_address
}
class "Primaite" as primaite.environment.primaite_env.Primaite {
ACTION_SPACE_ACL_ACTION_VALUES : int
ACTION_SPACE_ACL_PERMISSION_VALUES : int
ACTION_SPACE_NODE_ACTION_VALUES : int
ACTION_SPACE_NODE_PROPERTY_VALUES : int
acl
action_dict : dict, Dict[int, List[int]]
action_space : Discrete, Space
action_type : int
actual_episode_count
agent_identifier
average_reward : float
env_obs : ndarray, tuple
episode_av_reward_writer
episode_count : int
episode_steps : int
green_iers : Dict[str, IER]
green_iers_reference : Dict[str, IER]
lay_down_config
links : Dict[str, Link]
links_post_blue : dict
links_post_pol : dict
links_post_red : dict
links_reference : Dict[str, Link]
max_number_acl_rules : int
network : Graph
network_reference : Graph
node_pol : Dict[str, NodeStateInstructionGreen]
nodes : Dict[str, NodeUnion]
nodes_post_blue : dict
nodes_post_pol : dict
nodes_post_red : dict
nodes_reference : Dict[str, NodeUnion]
num_links : int
num_nodes : int
num_ports : int
num_services : int
obs_config : dict
obs_handler
observation_space : Tuple, Box, Space
observation_type
ports_list : List[str]
red_iers : Dict[str, IER], dict
red_node_pol : dict, Dict[str, NodeStateInstructionRed]
services_list : List[str]
session_path : Final[Path]
step_count : int
step_info : Dict[Any]
timestamp_str : Final[str]
total_reward : float
total_step_count : int
training_config
transaction_writer
apply_actions_to_acl(_action: int) -> None
apply_actions_to_nodes(_action: int) -> None
apply_time_based_updates() -> None
close() -> None
create_acl_action_dict() -> Dict[int, List[int]]
create_acl_rule(item: Dict) -> None
create_green_ier(item: Dict) -> None
create_green_pol(item: Dict) -> None
create_link(item: Dict) -> None
create_node(item: Dict) -> None
create_node_action_dict() -> Dict[int, List[int]]
create_node_and_acl_action_dict() -> Dict[int, List[int]]
create_ports_list(ports: Dict) -> None
create_red_ier(item: Dict) -> None
create_red_pol(item: Dict) -> None
create_services_list(services: Dict) -> None
get_action_info(action_info: Dict) -> None
get_observation_info(observation_info: Dict) -> None
init_acl() -> None
init_observations() -> Tuple[spaces.Space, np.ndarray]
interpret_action_and_apply(_action: int) -> None
load_lay_down_config() -> None
output_link_status() -> None
reset() -> np.ndarray
reset_environment() -> None
reset_node(item: Dict) -> None
save_obs_config(obs_config: dict) -> None
set_as_eval() -> None
step(action: int) -> Tuple[np.ndarray, float, bool, Dict]
update_environent_obs() -> None
}
class "PrimaiteSession" as primaite.primaite_session.PrimaiteSession {
evaluation_path : Optional[Path], Path
is_load_session : bool
learning_path : Optional[Path], Path
session_path : Optional[Path], Path
timestamp_str : str, Optional[str]
close() -> None
evaluate() -> None
learn() -> None
setup() -> None
}
class "Protocol" as primaite.common.protocol.Protocol {
load : int
name : str
add_load(_load: int) -> None
clear_load() -> None
get_load() -> int
get_name() -> str
}
class "RLlibAgent" as primaite.agents.rllib.RLlibAgent {
{abstract}evaluate() -> None
{abstract}export() -> None
learn() -> None
{abstract}load(path: Union[str, Path]) -> RLlibAgent
save(overwrite_existing: bool) -> None
}
class "RandomAgent" as primaite.agents.simple.RandomAgent {
}
class "SB3Agent" as primaite.agents.sb3.SB3Agent {
is_eval : bool
evaluate() -> None
{abstract}export() -> None
learn() -> None
save() -> None
}
class "Service" as primaite.common.service.Service {
name : str
patching_count : int
port : str
software_state : GOOD
reduce_patching_count() -> None
}
class "ServiceNode" as primaite.nodes.service_node.ServiceNode {
services : Dict[str, Service]
add_service(service: Service) -> None
get_service_state(protocol_name: str) -> SoftwareState
has_service(protocol_name: str) -> bool
service_is_overwhelmed(protocol_name: str) -> bool
service_running(protocol_name: str) -> bool
set_service_state(protocol_name: str, software_state: SoftwareState) -> None
set_service_state_if_not_compromised(protocol_name: str, software_state: SoftwareState) -> None
update_booting_status() -> None
update_resetting_status() -> None
update_services_patching_status() -> None
}
class "SessionOutputWriter" as primaite.utils.session_output_writer.SessionOutputWriter {
learning_session : bool
transaction_writer : bool
close() -> None
write(data: Union[Tuple, Transaction]) -> None
}
class "TrainingConfig" as primaite.config.training_config.TrainingConfig {
action_type
agent_framework
agent_identifier
agent_load_file : Optional[str]
all_ok : float
checkpoint_every_n_episodes : int
compromised : float
compromised_should_be_good : float
compromised_should_be_overwhelmed : float
compromised_should_be_patching : float
corrupt : float
corrupt_should_be_destroyed : float
corrupt_should_be_good : float
corrupt_should_be_repairing : float
corrupt_should_be_restoring : float
deep_learning_framework
destroyed : float
destroyed_should_be_corrupt : float
destroyed_should_be_good : float
destroyed_should_be_repairing : float
destroyed_should_be_restoring : float
deterministic : bool
file_system_repairing_limit : int
file_system_restoring_limit : int
file_system_scanning_limit : int
good_should_be_compromised : float
good_should_be_corrupt : float
good_should_be_destroyed : float
good_should_be_overwhelmed : float
good_should_be_patching : float
good_should_be_repairing : float
good_should_be_restoring : float
green_ier_blocked : float
hard_coded_agent_view
implicit_acl_rule
load_agent : bool
max_number_acl_rules : int
node_booting_duration : int
node_reset_duration : int
node_shutdown_duration : int
num_eval_episodes : int
num_eval_steps : int
num_train_episodes : int
num_train_steps : int
observation_space : dict
observation_space_high_value : int
off_should_be_on : float
off_should_be_resetting : float
on_should_be_off : float
on_should_be_resetting : float
os_patching_duration : int
overwhelmed : float
overwhelmed_should_be_compromised : float
overwhelmed_should_be_good : float
overwhelmed_should_be_patching : float
patching : float
patching_should_be_compromised : float
patching_should_be_good : float
patching_should_be_overwhelmed : float
random_red_agent : bool
red_ier_running : float
repairing : float
repairing_should_be_corrupt : float
repairing_should_be_destroyed : float
repairing_should_be_good : float
repairing_should_be_restoring : float
resetting : float
resetting_should_be_off : float
resetting_should_be_on : float
restoring : float
restoring_should_be_corrupt : float
restoring_should_be_destroyed : float
restoring_should_be_good : float
restoring_should_be_repairing : float
sb3_output_verbose_level
scanning : float
seed : Optional[int]
service_patching_duration : int
session_type
time_delay : int
from_dict(config_dict: Dict[str, Any]) -> TrainingConfig
to_dict(json_serializable: bool) -> Dict
}
class "Transaction" as primaite.transactions.transaction.Transaction {
action_space : Optional[int]
agent_identifier
episode_number : int
obs_space : str
obs_space_description : NoneType, Optional[List[str]], list
obs_space_post : Optional[Union['np.ndarray', Tuple['np.ndarray']]]
obs_space_pre : Optional[Union['np.ndarray', Tuple['np.ndarray']]]
reward : Optional[float], float
step_number : int
timestamp : datetime
as_csv_data() -> Tuple[List, List]
}
primaite.agents.hardcoded_abc.HardCodedAgentSessionABC --|> primaite.agents.agent_abc.AgentSessionABC
primaite.agents.hardcoded_acl.HardCodedACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.hardcoded_node.HardCodedNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.rllib.RLlibAgent --|> primaite.agents.agent_abc.AgentSessionABC
primaite.agents.sb3.SB3Agent --|> primaite.agents.agent_abc.AgentSessionABC
primaite.agents.simple.DoNothingACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.simple.DoNothingNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.simple.DummyAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.simple.RandomAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.environment.observations.AccessControlList_ --|> primaite.environment.observations.AbstractObservationComponent
primaite.environment.observations.LinkTrafficLevels --|> primaite.environment.observations.AbstractObservationComponent
primaite.environment.observations.NodeLinkTable --|> primaite.environment.observations.AbstractObservationComponent
primaite.environment.observations.NodeStatuses --|> primaite.environment.observations.AbstractObservationComponent
primaite.nodes.active_node.ActiveNode --|> primaite.nodes.node.Node
primaite.nodes.passive_node.PassiveNode --|> primaite.nodes.node.Node
primaite.nodes.service_node.ServiceNode --|> primaite.nodes.active_node.ActiveNode
primaite.common.service.Service --|> primaite.nodes.service_node.ServiceNode
primaite.acl.access_control_list.AccessControlList --* primaite.environment.primaite_env.Primaite : acl
primaite.acl.acl_rule.ACLRule --* primaite.acl.access_control_list.AccessControlList : acl_implicit_rule
primaite.agents.hardcoded_acl.HardCodedACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.hardcoded_node.HardCodedNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.rllib.RLlibAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.sb3.SB3Agent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.DoNothingACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.DoNothingNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.DummyAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.RandomAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.config.training_config.TrainingConfig --* primaite.agents.agent_abc.AgentSessionABC : _training_config
primaite.config.training_config.TrainingConfig --* primaite.environment.primaite_env.Primaite : training_config
primaite.environment.observations.ObservationsHandler --* primaite.environment.primaite_env.Primaite : obs_handler
primaite.environment.primaite_env.Primaite --* primaite.agents.agent_abc.AgentSessionABC : _env
primaite.environment.primaite_env.Primaite --* primaite.agents.hardcoded_abc.HardCodedAgentSessionABC : _env
primaite.environment.primaite_env.Primaite --* primaite.agents.sb3.SB3Agent : _env
primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : episode_av_reward_writer
primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : transaction_writer
primaite.config.training_config.TrainingConfig --o primaite.nodes.node.Node : config_values
primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen --* primaite.environment.primaite_env.Primaite
primaite.nodes.node_state_instruction_red.NodeStateInstructionRed --* primaite.environment.primaite_env.Primaite
primaite.pol.ier.IER --* primaite.environment.primaite_env.Primaite
primaite.common.protocol.Protocol --o primaite.links.link.Link
primaite.links.link.Link --* primaite.environment.primaite_env.Primaite
primaite.config.training_config.TrainingConfig --o primaite.nodes.active_node.ActiveNode
primaite.utils.session_output_writer.SessionOutputWriter --> primaite.transactions.transaction.Transaction
primaite.transactions.transaction.Transaction --> primaite.environment.primaite_env.Primaite
@enduml


@@ -1,6 +1,6 @@
.. only:: comment
-   Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+   © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.


@@ -1,6 +1,6 @@
.. only:: comment
-   Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+   © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.


@@ -1,6 +1,6 @@
.. only:: comment
-   Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+   © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
..
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without


@@ -1,4 +1,4 @@
-# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
@@ -19,8 +19,8 @@ sys.path.insert(0, os.path.abspath("../"))
# -- Project information -----------------------------------------------------
year = datetime.datetime.now().year
project = "PrimAITE"
-copyright = f"Copyright (C) QinetiQ Training and Simulation Ltd 2021 - {year}"
-author = "QinetiQ Training and Simulation Ltd"
+copyright = f"Copyright (C) Defence Science and Technology Laboratory UK 2021 - {year}"
+author = "Defence Science and Technology Laboratory UK"
# The short Major.Minor.Build version
with open("../src/primaite/VERSION", "r") as file:


@@ -1,6 +1,6 @@
.. only:: comment
-   Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+   © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Welcome to PrimAITE's documentation
====================================


@@ -1,6 +1,6 @@
.. only:: comment
-   Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+   © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _about:


@@ -1,6 +1,6 @@
.. only:: comment
-   Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
+   © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _config:
@@ -17,10 +17,10 @@ PrimAITE uses two configuration files for its operation:
Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERs) and Access Control Rules.
-Environment Config:
+Training Config:
*******************
-The environment config file consists of the following attributes:
+The Training Config file consists of the following attributes:
**Generic Config Values**

View File

@@ -1,6 +1,6 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Custom Agents
=============
@@ -130,7 +130,7 @@ Finally, specify your agent in your training config.
.. code-block:: yaml
# ~/primaite/config/path/to/your/config_main.yaml
# ~/primaite/2.0.0rc2/config/path/to/your/config_main.yaml
# Training Config File

View File

@@ -1,6 +1,6 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. role:: raw-html(raw)
:format: html

View File

@@ -1,6 +1,6 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _getting-started:
@@ -41,12 +41,12 @@ Install PrimAITE
.. code-tab:: bash
:caption: Unix
mkdir ~/primaite
mkdir ~/primaite/2.0.0rc2
.. code-tab:: powershell
:caption: Windows (Powershell)
mkdir ~\primaite
mkdir ~\primaite\2.0.0rc2
2. Navigate to the primaite directory and create a new Python virtual environment (venv)
@@ -55,13 +55,13 @@ Install PrimAITE
.. code-tab:: bash
:caption: Unix
cd ~/primaite
cd ~/primaite/2.0.0rc2
python3 -m venv .venv
.. code-tab:: powershell
:caption: Windows (Powershell)
cd ~\primaite
cd ~\primaite\2.0.0rc2
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory

View File

@@ -1,6 +1,6 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Glossary
=============
@@ -77,5 +77,5 @@ Glossary
Gym
PrimAITE uses the Gym reinforcement learning framework API to create a training environment and interface with RL agents. Gym defines a common way of creating observations, actions, and rewards.
User data directory
PrimAITE supports upgrading software versions while retaining user data. The user data directory is where configs, notebooks, and results are stored; this location is `~/primaite` on Linux/Darwin and `C:\Users\<username>\primaite` on Windows.
User app home
PrimAITE supports upgrading software versions while retaining user data. The user app home is where configs, notebooks, and results are stored; this location is `~/primaite/<version>` on Linux/Darwin and `C:\Users\<username>\primaite\<version>` on Windows.
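As a rough illustration of the versioned layout described above (a sketch, not the packaged code; the string "2.0.0rc2" stands in for ``primaite.__version__``):

```python
from pathlib import Path

version = "2.0.0rc2"  # stand-in for primaite.__version__
user_app_home = Path.home() / "primaite" / version  # ~/primaite/2.0.0rc2 on Linux/Darwin
sessions_dir = user_app_home / "sessions"           # configs, notebooks, and results live alongside
```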

View File

@@ -1,6 +1,6 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
v1.2 to v2.0 Migration guide
============================
@@ -31,7 +31,7 @@ v1.2 to v2.0 Migration guide
**3. Location of configs**
In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, and now when you install and setup PrimAITE, config files are stored in your user data location. On Linux/OSX, this is stored in ``~/primaite/config``. On Windows, this is stored in ``C:\Users\<your username>\primaite\configs``. Upon first setup, the configs folder is populated with some default yaml files. It is recommended that you store all your custom configuration files here.
In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, and now when you install and set up PrimAITE, config files are stored in your user data location. On Linux/OSX, this is stored in ``~/primaite/2.0.0rc2/config``. On Windows, this is stored in ``C:\Users\<your username>\primaite\2.0.0rc2\config``. Upon first setup, the configs folder is populated with some default yaml files. It is recommended that you store all your custom configuration files here.
**4. Contents of configs**

View File

@@ -1,6 +1,6 @@
.. only:: comment
Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _run a primaite session:
@@ -20,16 +20,16 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite
cd ~/primaite/2.0.0rc2
source ./.venv/bin/activate
primaite session ./config/my_training_config.yaml ./config/my_lay_down_config.yaml
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
.. code-tab:: powershell
:caption: Powershell CLI
cd ~\primaite
cd ~\primaite\2.0.0rc2
.\.venv\Scripts\activate
primaite session .\config\my_training_config.yaml .\config\my_lay_down_config.yaml
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
.. code-tab:: python
@@ -41,11 +41,13 @@ Both the ``primaite session`` and :func:`primaite.main.run` take a training conf
lay_down_config = <path to lay down config yaml file>
run(training_config, lay_down_config)
When a session is run, a session output sub-directory is created in the user's app sessions directory (``~/primaite/sessions``).
The sub-directory is formatted as follows: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>/``
When a session is run, a session output sub-directory is created in the user's app sessions directory (``~/primaite/2.0.0rc2/sessions``).
The sub-directory is formatted as follows: ``~/primaite/2.0.0rc2/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>/``
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/2.0.0rc2/sessions/2023-01-31/2023-01-31_17-30-00/``.
``primaite session`` can be run in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``.
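The timestamp-to-path mapping described above can be sketched in plain Python (a minimal illustration; the function name here is hypothetical, not the packaged helper):

```python
from datetime import datetime
from pathlib import Path

def session_sub_dir(sessions_dir: Path, timestamp: datetime) -> Path:
    """Build the <yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss> session output sub-directory."""
    date_dir = timestamp.strftime("%Y-%m-%d")
    session_dir = timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    return sessions_dir / date_dir / session_dir

# 17:30:00 on 31st January 2023, as in the example above:
path = session_sub_dir(Path("sessions"), datetime(2023, 1, 31, 17, 30, 0))
```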
Outputs
@@ -108,43 +110,44 @@ For each training session, assuming the agent being trained implements the *save
~/
└── primaite/
    └── sessions/
        └── 2023-07-18/
            └── 2023-07-18_11-06-04/
                ├── evaluation/
                │   ├── all_transactions_2023-07-18_11-06-04.csv
                │   ├── average_reward_per_episode_2023-07-18_11-06-04.csv
                │   └── average_reward_per_episode_2023-07-18_11-06-04.png
                ├── learning/
                │   ├── all_transactions_2023-07-18_11-06-04.csv
                │   ├── average_reward_per_episode_2023-07-18_11-06-04.csv
                │   ├── average_reward_per_episode_2023-07-18_11-06-04.png
                │   ├── checkpoints/
                │   │   └── sb3ppo_10.zip
                │   ├── SB3_PPO.zip
                │   └── tensorboard_logs/
                │       ├── PPO_1/
                │       │   └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
                │       ├── PPO_2/
                │       │   └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
                │       ├── PPO_3/
                │       │   └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
                │       ├── PPO_4/
                │       │   └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
                │       ├── PPO_5/
                │       │   └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
                │       ├── PPO_6/
                │       │   └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
                │       ├── PPO_7/
                │       │   └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
                │       ├── PPO_8/
                │       │   └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
                │       ├── PPO_9/
                │       │   └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
                │       └── PPO_10/
                │           └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
                ├── network_2023-07-18_11-06-04.png
                └── session_metadata.json
    └── 2.0.0rc2/
        └── sessions/
            └── 2023-07-18/
                └── 2023-07-18_11-06-04/
                    ├── evaluation/
                    │   ├── all_transactions_2023-07-18_11-06-04.csv
                    │   ├── average_reward_per_episode_2023-07-18_11-06-04.csv
                    │   └── average_reward_per_episode_2023-07-18_11-06-04.png
                    ├── learning/
                    │   ├── all_transactions_2023-07-18_11-06-04.csv
                    │   ├── average_reward_per_episode_2023-07-18_11-06-04.csv
                    │   ├── average_reward_per_episode_2023-07-18_11-06-04.png
                    │   ├── checkpoints/
                    │   │   └── sb3ppo_10.zip
                    │   ├── SB3_PPO.zip
                    │   └── tensorboard_logs/
                    │       ├── PPO_1/
                    │       │   └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
                    │       ├── PPO_2/
                    │       │   └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
                    │       ├── PPO_3/
                    │       │   └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
                    │       ├── PPO_4/
                    │       │   └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
                    │       ├── PPO_5/
                    │       │   └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
                    │       ├── PPO_6/
                    │       │   └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
                    │       ├── PPO_7/
                    │       │   └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
                    │       ├── PPO_8/
                    │       │   └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
                    │       ├── PPO_9/
                    │       │   └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
                    │       └── PPO_10/
                    │           └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
                    ├── network_2023-07-18_11-06-04.png
                    └── session_metadata.json
Loading a session
-----------------
@@ -157,14 +160,14 @@ A previous session can be loaded by providing the **directory** of the previous
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite
cd ~/primaite/2.0.0rc2
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite
cd ~\primaite\2.0.0rc2
.\.venv\Scripts\activate
primaite session --load "path\to\session"
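A loaded session is identified purely by its directory. A hypothetical sketch of reading back its ``session_metadata.json`` (the file name comes from the output tree above; the metadata keys here are illustrative only):

```python
import json
import tempfile
from pathlib import Path

# Build a stand-in session directory containing a metadata file,
# then read it back the way a session loader might.
session_dir = Path(tempfile.mkdtemp())
(session_dir / "session_metadata.json").write_text(json.dumps({"uuid": "demo-session"}))

metadata = json.loads((session_dir / "session_metadata.json").read_text())
```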

View File

@@ -5,13 +5,13 @@ build-backend = "setuptools.build_meta"
[project]
name = "primaite"
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
authors = [{name="QinetiQ Training and Simulation Ltd"}]
license = {text = "MIT License"}
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
license = {file = "LICENSE"}
requires-python = ">=3.8, <3.11"
dynamic = ["version", "readme"]
classifiers = [
"License :: MIT License",
"Development Status :: 4 - Beta",
"License :: OSI Approved :: MIT License",
"Development Status :: 5 - Production/Stable",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"Operating System :: POSIX :: Linux",
@@ -55,8 +55,10 @@ dev = [
"build==0.10.0",
"flake8==6.0.0",
"furo==2023.3.27",
"gputil==1.4.0",
"pip-licenses==4.3.0",
"pre-commit==2.20.0",
"pylatex==1.4.1",
"pytest==7.2.0",
"pytest-xdist==3.3.1",
"pytest-cov==4.0.0",

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from setuptools import setup
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa

View File

@@ -1 +1 @@
2.0.0rc1
2.0.0rc2

View File

@@ -1,23 +1,126 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import logging
import logging.config
import shutil
import sys
from bisect import bisect
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any, Dict, Final
from typing import Any, Dict, Final, List
import pkg_resources
import yaml
from platformdirs import PlatformDirs
_PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
"""An instance of `PlatformDirs` set with appname='primaite'."""
with open(Path(__file__).parent.resolve() / "VERSION", "r") as file:
__version__ = file.readline().strip()
class _PrimaitePaths:
"""
A Primaite paths class that leverages PlatformDirs.
The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__``.
"""
def __init__(self):
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
def _get_dirs_properties(self) -> List[str]:
class_items = self.__class__.__dict__.items()
return [k for k, v in class_items if isinstance(v, property)]
def mkdirs(self):
"""
Creates all Primaite directories.
Does this by retrieving all properties of the _PrimaitePaths class and accessing each one.
"""
for p in self._get_dirs_properties():
getattr(self, p)
@property
def user_home_path(self) -> Path:
"""The PrimAITE user home path."""
path = Path.home() / "primaite" / __version__
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_sessions_path(self) -> Path:
"""The PrimAITE user sessions path."""
path = self.user_home_path / "sessions"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_config_path(self) -> Path:
"""The PrimAITE user config path."""
path = self.user_home_path / "config"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_notebooks_path(self) -> Path:
"""The PrimAITE user notebooks path."""
path = self.user_home_path / "notebooks"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_home_path(self) -> Path:
"""The PrimAITE app home path."""
path = self._dirs.user_data_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_dir_path(self) -> Path:
"""The PrimAITE app config directory path."""
path = self._dirs.user_config_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_file_path(self) -> Path:
"""The PrimAITE app config file path."""
return self.app_config_dir_path / "primaite_config.yaml"
@property
def app_log_dir_path(self) -> Path:
"""The PrimAITE app log directory path."""
if sys.platform == "win32":
path = self.app_home_path / "logs"
else:
path = self._dirs.user_log_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_log_file_path(self) -> Path:
"""The PrimAITE app log file path."""
return self.app_log_dir_path / "primaite.log"
def __repr__(self):
properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()])
return f"{self.__class__.__name__}({properties_str})"
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
def _host_primaite_config():
if not PRIMAITE_PATHS.app_config_file_path.exists():
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_host_primaite_config()
def _get_primaite_config() -> Dict:
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
config_path = PRIMAITE_PATHS.app_config_file_path
if not config_path.exists():
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
with open(config_path, "r") as file:
@@ -36,35 +139,7 @@ def _get_primaite_config() -> Dict:
_PRIMAITE_CONFIG = _get_primaite_config()
_USER_DIRS: Final[Path] = Path.home() / "primaite"
"""The users home space for PrimAITE which is located at: ~/primaite."""
NOTEBOOKS_DIR: Final[Path] = _USER_DIRS / "notebooks"
"""
The path to the users notebooks directory as an instance of `Path` or
`PosixPath`, depending on the OS.
Users notebooks are stored at: ``~/primaite/notebooks``.
"""
USERS_CONFIG_DIR: Final[Path] = _USER_DIRS / "config"
"""
The path to the users config directory as an instance of `Path` or
`PosixPath`, depending on the OS.
Users config files are stored at: ``~/primaite/config``.
"""
SESSIONS_DIR: Final[Path] = _USER_DIRS / "sessions"
"""
The path to the users PrimAITE Sessions directory as an instance of `Path` or
`PosixPath`, depending on the OS.
Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
"""
# region Setup Logging
class _LevelFormatter(Formatter):
"""
A custom level-specific formatter.
@@ -87,14 +162,6 @@ class _LevelFormatter(Formatter):
return formatter.format(record)
def _log_dir() -> Path:
if sys.platform == "win32":
dir_path = _PLATFORM_DIRS.user_data_path / "logs"
else:
dir_path = _PLATFORM_DIRS.user_log_path
return dir_path
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
@@ -105,18 +172,10 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
}
)
LOG_DIR: Final[Path] = _log_dir()
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""
LOG_DIR.mkdir(exist_ok=True, parents=True)
LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
filename=LOG_PATH,
filename=PRIMAITE_PATHS.app_log_file_path,
maxBytes=10485760, # 10MB
backupCount=9, # Max 100MB of logs
encoding="utf8",
@@ -146,10 +205,3 @@ def getLogger(name: str) -> Logger: # noqa
logger.setLevel(_PRIMAITE_CONFIG["log_level"])
return logger
# endregion
with open(Path(__file__).parent.resolve() / "VERSION", "r") as file:
__version__ = file.readline()

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Access Control List. Models firewall functionality."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A class that implements the access control list implementation for the network."""
import logging
from typing import Dict, Final, List, Union

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Common interface between RL agents from different libraries and PrimAITE."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Union
from uuid import uuid4
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite import getLogger, PRIMAITE_PATHS
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
@@ -25,14 +25,14 @@ def get_session_path(session_timestamp: datetime) -> Path:
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
~/primaite/2.0.0rc2/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_path
session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
@@ -230,7 +230,6 @@ class AgentSessionABC(ABC):
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
self._plot_av_reward_per_episode(learning_session=True)
@abstractmethod
def evaluate(
@@ -243,9 +242,9 @@ class AgentSessionABC(ABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_evaluate:
self._plot_av_reward_per_episode(learning_session=False)
self._update_session_metadata_file()
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
@abstractmethod

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import time
from abc import abstractmethod
from pathlib import Path

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from typing import Dict, List, Union
import numpy as np

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC

View File

@@ -1,8 +1,9 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
import shutil
import zipfile
from datetime import datetime
from logging import Logger
from pathlib import Path
@@ -17,8 +18,9 @@ from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
from primaite.environment.primaite_env import Primaite
from primaite.exceptions import RLlibAgentError
_LOGGER: Logger = getLogger(__name__)
@@ -68,11 +70,14 @@ class RLlibAgent(AgentSessionABC):
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.error(msg)
print(msg)
raise NotImplementedError
_LOGGER.critical(msg)
raise NotImplementedError(msg)
super().__init__(training_config_path, lay_down_config_path)
if self._training_config.session_type == SessionType.EVAL:
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
_LOGGER.critical(msg)
raise RLlibAgentError(msg)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -98,6 +103,7 @@ class RLlibAgent(AgentSessionABC):
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
self._train_agent = None # Required to capture the learning agent to close after eval
def _update_session_metadata_file(self) -> None:
"""
@@ -179,20 +185,73 @@ class RLlibAgent(AgentSessionABC):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent = self._agent
else:
self._agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
def _unpack_saved_agent_into_eval(self) -> Path:
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
agent_restore_path = self.evaluation_path / "agent_restore"
if agent_restore_path.exists():
shutil.rmtree(agent_restore_path)
agent_restore_path.mkdir()
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
zip_file.extractall(agent_restore_path)
return agent_restore_path
def _setup_eval(self):
self._can_learn = False
self._can_evaluate = True
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
def evaluate(
self,
**kwargs: None,
) -> None:
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._setup_eval()
self._env: Primaite = Primaite(
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
)
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action = self._agent.compute_single_action(observation=obs, explore=False)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
# Now we're safe to close the learning agent and write the mean rewards per episode for it
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# Perform a clean-up of the unpacked agent
if (self.evaluation_path / "agent_restore").exists():
shutil.rmtree((self.evaluation_path / "agent_restore"))
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
@@ -153,6 +153,8 @@ class SB3Agent(AgentSessionABC):
# save agent
self.save()
self._plot_av_reward_per_episode(learning_session=True)
def evaluate(
self,
**kwargs: Any,

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import numpy as np

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from typing import Dict, List, Union
import numpy as np

View File

@@ -1,18 +1,15 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Provides a CLI using Typer as an entry point."""
import logging
import os
import shutil
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
import typer
import yaml
from platformdirs import PlatformDirs
from typing_extensions import Annotated
from primaite import PRIMAITE_PATHS
from primaite.data_viz import PlotlyTemplate
app = typer.Typer()
@@ -47,10 +44,10 @@ def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
"""
import re
from primaite import LOG_PATH
from primaite import PRIMAITE_PATHS
if os.path.isfile(LOG_PATH):
with open(LOG_PATH) as file:
if os.path.isfile(PRIMAITE_PATHS.app_log_file_path):
with open(PRIMAITE_PATHS.app_log_file_path) as file:
lines = file.readlines()
for line in lines[-last_n:]:
print(re.sub(r"\n*", "", line))
@@ -70,16 +67,13 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) ->
For example, to set the to debug, call: primaite log-level DEBUG
"""
app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
if user_config_path.exists():
with open(user_config_path, "r") as file:
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if level:
primaite_config["logging"]["log_level"] = level.value
with open(user_config_path, "w") as file:
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE Log Level: {level}")
else:
@@ -118,16 +112,8 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
# Does this way to avoid using PrimAITE package before config is loaded
app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, user_config_path)
from primaite import getLogger
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
_LOGGER = getLogger(__name__)
@@ -136,7 +122,7 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building the PrimAITE app directories...")
setup_app_dirs.run()
PRIMAITE_PATHS.mkdirs()
_LOGGER.info("Rebuilding the demo notebooks...")
reset_demo_notebooks.run(overwrite_existing=True)
@@ -157,11 +143,11 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/config/example_config/training/training_config_main.yaml.
~/primaite/2.0.0rc2/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
~/primaite/2.0.0rc2/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
@@ -173,15 +159,18 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
from primaite.main import run
if load is not None:
# run a loaded session
run(session_path=load)
if not tc:
tc = main_training_config_path()
else:
# start a new session using tc and ldc
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
run(training_config_path=tc, lay_down_config_path=ldc)
run(training_config_path=tc, lay_down_config_path=ldc)
@app.command()
@@ -195,16 +184,13 @@ def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
"""
app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
if user_config_path.exists():
with open(user_config_path, "r") as file:
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if template:
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
with open(user_config_path, "w") as file:
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE plotly template: {template.value}")
else:

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Objects which are shared between many PrimAITE modules."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Enumerations for APE."""
from enum import Enum, IntEnum

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The protocol class."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Service class."""
from primaite.common.enums import SoftwareState

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Configuration parameters for running experiments."""

View File

@@ -1,15 +1,15 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Union
import yaml
from primaite import getLogger, USERS_CONFIG_DIR
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
from dataclasses import dataclass, field
@@ -8,7 +8,7 @@ from typing import Any, Dict, Final, Optional, Union
import yaml
from primaite import getLogger, USERS_CONFIG_DIR
from primaite import getLogger, PRIMAITE_PATHS
from primaite.common.enums import (
ActionType,
AgentFramework,
@@ -22,7 +22,7 @@ from primaite.common.enums import (
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training"
def main_training_config_path() -> Path:
@@ -246,6 +246,7 @@ class TrainingConfig:
return data
def __str__(self) -> str:
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
tc = f"{self.agent_framework}, "
if self.agent_framework is AgentFramework.RLLIB:
tc += f"{self.deep_learning_framework}, "
@@ -253,7 +254,7 @@ class TrainingConfig:
if self.agent_identifier is AgentIdentifier.HARDCODED:
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={self.observation_space}, "
tc += f"observation_space={obs_str}, "
if self.session_type is SessionType.TRAIN:
tc += f"{self.num_train_episodes} episodes @ "
tc += f"{self.num_train_steps} steps"
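The `__str__` fix above swaps the raw `observation_space` dict for a comma-joined list of component names. A minimal sketch of that join, using an illustrative config dict (component names are examples only):

```python
# Stand-in for TrainingConfig.observation_space as loaded from YAML.
observation_space = {
    "components": [
        {"name": "NODE_LINK_TABLE"},
        {"name": "NODE_STATUSES"},
    ]
}

# The diff replaces f"observation_space={self.observation_space}" with this
# shorter, human-readable summary of the component names.
obs_str = ",".join(c["name"] for c in observation_space["components"])
print(obs_str)  # NODE_LINK_TABLE,NODE_STATUSES
```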

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utility to generate plots of session metrics after a PrimAITE session."""
from enum import Enum

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Dict, Optional, Union
@@ -7,13 +7,12 @@ import polars as pl
import yaml
from plotly.graph_objs import Figure
from primaite import _PLATFORM_DIRS
from primaite import PRIMAITE_PATHS
def _get_plotly_config() -> Dict:
def get_plotly_config() -> Dict:
"""Get the plotly config from primaite_config.yaml."""
user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
with open(user_config_path, "r") as file:
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
return primaite_config["session"]["outputs"]["plots"]
@@ -41,7 +40,7 @@ def plot_av_reward_per_episode(
if subtitle:
title = subtitle
config = _get_plotly_config()
config = get_plotly_config()
layout = go.Layout(
autosize=config["size"]["auto_size"],
width=config["size"]["width"],

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Main environment module containing the PRIMary AI Training Environment (Primaite) class."""
import copy
import logging
@@ -261,10 +261,12 @@ class Primaite(Env):
self, transaction_writer=True, learning_session=True
)
self.is_eval = False
@property
def actual_episode_count(self) -> int:
"""Shifts the episode_count by -1 for RLlib."""
if self.training_config.agent_framework is AgentFramework.RLLIB:
"""Shifts the episode_count by -1 for RLlib learning session."""
if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
return self.episode_count - 1
return self.episode_count
@@ -276,6 +278,7 @@ class Primaite(Env):
self.step_count = 0
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
self.is_eval = True
def _write_av_reward_per_episode(self) -> None:
if self.actual_episode_count > 0:
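The change above shifts the episode count by one for RLlib learning sessions only, now guarded by the new `is_eval` flag. A self-contained stand-in class (not the real `Primaite` env) showing the guarded shift:

```python
# Minimal stand-in for the actual_episode_count property added in the diff.
class EpisodeCounter:
    def __init__(self, framework: str):
        self.framework = framework  # e.g. "RLLIB" or "SB3"
        self.episode_count = 0
        self.is_eval = False

    @property
    def actual_episode_count(self) -> int:
        # RLlib pre-increments the count during learning, so shift by -1
        # there only; evaluation counts are already correct.
        if self.framework == "RLLIB" and not self.is_eval:
            return self.episode_count - 1
        return self.episode_count

c = EpisodeCounter("RLLIB")
c.episode_count = 5
print(c.actual_episode_count)  # 4 (learning)
c.is_eval = True
print(c.actual_episode_count)  # 5 (evaluation)
```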

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements reward function."""
from logging import Logger
from typing import Dict, TYPE_CHECKING, Union

View File

@@ -0,0 +1,11 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
class PrimaiteError(Exception):
"""The root PrimAITE error."""
pass
class RLlibAgentError(PrimaiteError):
"""Raised when there is a generic error with an RLlib agent that is specific to PrimAITE."""
pass

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Network connections between nodes in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The link class."""
from typing import List

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Nodes represent network hosts in the simulation."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The base Node class."""
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union

View File

@@ -1,6 +1,5 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from primaite.common.enums import NodePOLType
@@ -9,8 +8,7 @@ if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
@dataclass()
class NodeStateInstructionRed(object):
class NodeStateInstructionRed:
"""The Node State Instruction class."""
def __init__(

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
import importlib.util
@@ -7,7 +7,7 @@ import subprocess
import sys
from logging import Logger
from primaite import getLogger, NOTEBOOKS_DIR
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER: Logger = getLogger(__name__)
@@ -26,7 +26,7 @@ def start_jupyter_session() -> None:
jupyter_cmd = "jupyter lab"
working_dir = os.getcwd()
os.chdir(NOTEBOOKS_DIR)
os.chdir(PRIMAITE_PATHS.user_notebooks_path)
subprocess.Popen(jupyter_cmd)
os.chdir(working_dir)
else:

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Pattern of Life - represents the actions of users on the network."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""
Information Exchange Requirements for APE.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict
@@ -250,6 +250,11 @@ def apply_red_agent_node_pol(
# continue --------------------------
target_node: NodeUnion = nodes[target_node_id]
# check if the initiator type is a str, and if so, cast it as
# NodePOLInitiator
if isinstance(initiator, str):
initiator = NodePOLInitiator[initiator]
# Base the action taken on the initiator type
if initiator == NodePOLInitiator.DIRECT:
# No conditions required, just apply the change
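The string-to-enum cast added above uses the standard `Enum["NAME"]` lookup. A minimal sketch with a stand-in enum (member names are illustrative, not the real `NodePOLInitiator` values):

```python
from enum import Enum

# Stand-in for primaite.common.enums.NodePOLInitiator.
class NodePOLInitiator(Enum):
    DIRECT = 1
    IER = 2

def coerce_initiator(initiator):
    # Mirror the diff's pattern: accept either an enum member or its name,
    # and normalise strings via name-based lookup.
    if isinstance(initiator, str):
        initiator = NodePOLInitiator[initiator]
    return initiator

print(coerce_initiator("DIRECT"))  # NodePOLInitiator.DIRECT
```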

View File

@@ -1,9 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
from typing import Any, Dict, Final, Optional, Tuple, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
@@ -16,6 +17,7 @@ from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, S
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.utils.session_metadata_parser import parse_session_metadata
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
_LOGGER = getLogger(__name__)
@@ -70,13 +72,7 @@ class PrimaiteSession:
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) # noqa
def setup(self) -> None:
"""Performs the session setup."""
@@ -186,3 +182,28 @@ class PrimaiteSession:
def close(self) -> None:
"""Closes the agent."""
self._agent_session.close()
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
return json.load(file)
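The convenience methods moved onto `PrimaiteSession` above all read session outputs back from disk. A hedged sketch of the `metadata_file_as_dict` pattern against a throwaway directory (the real path is the session's output directory):

```python
import json
import tempfile
from pathlib import Path

# Throwaway stand-in for a PrimAITE session output directory.
session_path = Path(tempfile.mkdtemp())
(session_path / "session_metadata.json").write_text(json.dumps({"uuid": "demo"}))

def metadata_file_as_dict(session_path: Path) -> dict:
    # Same shape as the method added to PrimaiteSession in the diff.
    with open(session_path / "session_metadata.json", "r") as file:
        return json.load(file)

print(metadata_file_as_dict(session_path))  # {'uuid': 'demo'}
```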

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utilities to prepare the user's data folders."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from primaite import getLogger

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import filecmp
import os
import shutil
@@ -7,7 +7,7 @@ from pathlib import Path
import pkg_resources
from primaite import getLogger, NOTEBOOKS_DIR
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER: Logger = getLogger(__name__)
@@ -23,7 +23,7 @@ def run(overwrite_existing: bool = True) -> None:
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
target_fp = NOTEBOOKS_DIR / Path(*path_split)
target_fp = PRIMAITE_PATHS.user_notebooks_path / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import filecmp
import os
import shutil
@@ -6,7 +6,7 @@ from pathlib import Path
import pkg_resources
from primaite import getLogger, USERS_CONFIG_DIR
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER = getLogger(__name__)
@@ -23,7 +23,7 @@ def run(overwrite_existing: bool = True) -> None:
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
target_fp = PRIMAITE_PATHS.user_config_path / "example_config" / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()

View File

@@ -1,29 +0,0 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from logging import Logger
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
_LOGGER: Logger = getLogger(__name__)
def run() -> None:
"""
Handles creation of application directories and user directories.
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
app directories in the correct locations based on the users OS.
"""
app_dirs = [
_USER_DIRS,
NOTEBOOKS_DIR,
LOG_DIR,
]
for app_dir in app_dirs:
if not app_dir.is_dir():
app_dir.mkdir(parents=True, exist_ok=True)
_LOGGER.info(f"Created directory: {app_dir}")
if __name__ == "__main__":
run()

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Record data of the system's state and agent's observations and actions."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Transaction class."""
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

View File

@@ -1,2 +1,2 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utilities for PrimAITE."""

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import os
from logging import Logger
from pathlib import Path

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
from pathlib import Path
from typing import Any, Dict, Union

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union
@@ -18,7 +18,7 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
"""
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
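The one-character fix above wraps the episode key in `int(...)`, so lookups by integer episode number work regardless of how the CSV reader typed the column. A sketch with a plain dict standing in for the polars frame:

```python
# Stand-in for pl.read_csv(...).to_dict(); the reader may yield episode
# numbers as strings (or another numeric type) rather than ints.
df_dict = {"Episode": ["1", "2", "3"], "Average Reward": [0.1, 0.2, 0.3]}

# The fix: int() normalises the keys so callers can index by episode number.
rewards = {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
print(rewards[2])  # 0.2
```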

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import csv
from logging import Logger
from typing import Final, List, Tuple, TYPE_CHECKING, Union

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Final

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Main Config File
# Generic config values

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Main Config File
# Generic config values

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
- item_type: PORTS
ports_list:
- port: '21'

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.
@@ -69,7 +69,7 @@ num_train_episodes: 10
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
num_eval_episodes: 3
# Number of time_steps for evaluation per episode
num_eval_steps: 256

View File

@@ -0,0 +1,164 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# "TF2" (Tensorflow 2.X)
# "TORCH" (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets how the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" (node and ACL actions)
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 3
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.

View File

@@ -1,11 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime
import json
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Union
from unittest.mock import patch
import pytest
@@ -13,7 +12,6 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -37,31 +35,6 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path)
self.setup()
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
return json.load(file)
@property
def env(self) -> Primaite:
"""Direct access to the env for ease of testing."""

View File

@@ -1 +1 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import tempfile
from datetime import datetime
from pathlib import Path

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Used to test the ACL functions."""
from primaite.acl.access_control_list import AccessControlList

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Used to test Active Node functions."""
import pytest

View File

@@ -1,4 +1,4 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Test env creation and behaviour with different observation spaces."""
import numpy as np

Some files were not shown because too many files have changed in this diff.