492 Commits

Author SHA1 Message Date
Chris McCarthy
4fb88c94e8 Merge pull request #9 from Autonomous-Resilient-Cyber-Defence/dev
updated primaite logo
2023-08-15 16:13:16 +01:00
Chris McCarthy
44de1531a5 Added PrimAITE_logo_transparent.png file 2023-08-15 13:56:06 +01:00
Chris McCarthy
5e526ba81e Merge pull request #7 from Autonomous-Resilient-Cyber-Defence/dev
v2.0.0
2023-08-15 13:29:06 +01:00
Chris McCarthy
506b0836ea Dropped sphinx-pipeline 2023-08-15 13:28:02 +01:00
ARCD
c2992c3348 Create spinx-pipeline 2023-08-15 13:23:25 +01:00
Chris McCarthy
aaf1b16912 Added sphinx docs build pipeline for GitHub pages on release 2023-08-15 11:26:15 +01:00
Chris McCarthy
6bf348e5c3 Added the DSTL MIT license and updated the license in pyproject.toml 2023-08-15 11:14:23 +01:00
Chris McCarthy
4ed31a52e5 Updated the What is PrimAITE? section in index.rst. Dropped the use of sphinx-code-tabs in the docs as building the docs in pdf (make latexpdf) is suddenly complaining about the tab buttons. 2023-08-03 16:04:23 +01:00
jamesshort1
efb230a5fc Update README.md 2023-07-31 09:17:48 +01:00
jamesshort1
da8f91c78c Update README.md 2023-07-31 09:16:24 +01:00
Chris McCarthy
3fae05d971 #1711 - Last minute docs changes 2023-07-28 14:41:39 +01:00
Chris McCarthy
5893ea8db4 Merge pull request #5 from Autonomous-Resilient-Cyber-Defence/dev
Finalised the support for legacy config files
2023-07-28 14:19:15 +01:00
Chris McCarthy
10ce2923c7 #1711 - Removed the legacy bools from the RLlibAgent constructor in primaite_session.py 2023-07-28 14:02:17 +01:00
Chris McCarthy
cd7ba9986c #1711 - Fully Integrated the legacy training config and lay down config options into the CLI, run PrimaiteSession, and Agent classes. Made the ese test in test_full_legacy_config_session.py use this new integrated option to read the legacy file. 2023-07-28 13:49:26 +01:00
Chris McCarthy
1084be914b Merge remote-tracking branch 'github/dev' into github_dev 2023-07-28 12:54:58 +01:00
Chris McCarthy
f01825b180 #1711 - Added the ability to load legacy lay down config files. Added extensive unit testing and end-to-end testing. Also added the ability to set exactly how many num_train_steps, num_eval_steps, num_train_episodes, and num_eval_episodes are used when converting a legacy training config. 2023-07-28 12:53:49 +01:00
jamesshort1
5bf9f7f4ea Update README.md 2023-07-28 09:41:29 +01:00
jamesshort1
f71b3480f0 Update README.md 2023-07-28 09:40:05 +01:00
jamesshort1
87741ba994 Update README.md 2023-07-28 09:39:05 +01:00
jamesshort1
0a1d17c9cc Update README.md 2023-07-27 14:57:08 +01:00
Chris McCarthy
e647b35f6f Merge pull request #4 from Autonomous-Resilient-Cyber-Defence/dev
v2.0.0
2023-07-27 11:42:13 +01:00
Chris McCarthy
b15be9796d Added GFX license conditions. Included LICENSE file in build. Fixed a few character issues in README.md 2023-07-27 11:40:29 +01:00
Chris McCarthy
b40fb09c1f Dropped MIT license until public release 2023-07-27 11:03:25 +01:00
jamesshort1
99ff8ca4e1 Update README.md 2023-07-27 08:59:43 +01:00
jamesshort1
ebb901c2b2 Update README.md 2023-07-27 08:59:24 +01:00
Chris McCarthy
094c2380cd Merge pull request #3 from Autonomous-Resilient-Cyber-Defence/dev
v2.0.0 - Added run section with primaite session command in the README.md
2023-07-26 22:11:47 +01:00
Chris McCarthy
1b8af0d862 Added run section with primaite session command in the README.md 2023-07-26 22:10:59 +01:00
Chris McCarthy
0b93ca41ab Merge pull request #2 from Autonomous-Resilient-Cyber-Defence/dev
v2.0.0 Added additional install instructions to the README.md
2023-07-26 21:50:50 +01:00
Chris McCarthy
69b0ea4572 Added additional install instructions to the README.md 2023-07-26 21:49:36 +01:00
Chris McCarthy
0fc596e06b Merge pull request #1 from Autonomous-Resilient-Cyber-Defence/dev
v2.0.0
2023-07-26 21:30:15 +01:00
Chris McCarthy
a5fa613bea Added project urls to pyproject.toml and a setup.cfg file for PyPi to pickup author and url 2023-07-26 21:11:15 +01:00
Chris McCarthy
46699880ce Added a CONTRIBUTING.md and added a URL to the Yawning-Titan reference in index.rst 2023-07-26 20:17:29 +01:00
Chris McCarthy
22f72139e3 Dropped the ADF build files and updated the package name install step in python-package.yml. Added bug_report.md and feature_request.md files for GitHub 2023-07-26 20:05:44 +01:00
Chris McCarthy
eec1c25989 Updated the README.md with developer install specific instructions 2023-07-26 19:49:24 +01:00
Chris McCarthy
514f239cc6 Merge remote-tracking branch 'devops/main' into github_dev 2023-07-26 19:43:46 +01:00
Chris McCarthy
21931e991c Create python-package.yml CI pipeline 2023-07-26 19:38:28 +01:00
Christopher McCarthy
0c601d0383 Merged PR 142: v2.0.0
v2.0.0

Related work items: #901, #1523, #1574, #1594, #1595, #1597, #1623, #1626, #1629, #1631, #1632, #1635, #1637, #1638, #1639, #1640, #1641, #1647, #1648, #1650
2023-07-26 18:20:28 +00:00
Chris McCarthy
8e3a0a0afa Ran final v2.0.0 benchmark and dropped the release candidate benchmarks 2023-07-26 18:19:11 +01:00
Chris McCarthy
7fe5df7fc4 Bumped version to 2.0.0 2023-07-26 14:38:57 +01:00
Christopher McCarthy
8bf1440f9b Merged PR 139: Re-run the benchmarks for v2.0.0rc1 and v2.0.0rc2 using the same config file....
## Summary
Re-ran the benchmarks for v2.0.0rc1 and v2.0.0rc2 using the same config file. As expected, the versions perform almost identically, as there are no real logic changes that would affect the agent between the two release candidates.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1574
2023-07-25 09:33:32 +00:00
Chris McCarthy
e9b52a69b7 Merge remote-tracking branch 'origin/dev' into release/2.0.0rc1
# Conflicts:
#	.gitignore
#	benchmark/results/PrimAITE Versions Learning Benchmark.png
#	benchmark/results/v2.0.0rc1/PrimAITE v2.0.0rc1 Learning Benchmark.pdf
#	benchmark/results/v2.0.0rc1/PrimAITE v2.0.0rc1 Learning Benchmark.png
#	benchmark/results/v2.0.0rc1/v2.0.0rc1_benchmark_metadata.json
2023-07-24 22:50:34 +01:00
Chris McCarthy
4f16105b67 Re-ran the benchmarks for v2.0.0rc1 and v2.0.0rc2 using the same config file. As expected, the versions perform almost identically, as there are no real logic changes that would affect the agent between the two release candidates 2023-07-24 22:43:22 +01:00
Christopher McCarthy
7e0f55cdb8 Merged PR 137: #1650 - Turned on the test. Also updated some references to the old primaite...
## Summary
- Turned on the test. Also updated some references to the old primaite paths. Finally, pushed the deployment status classifier to Development Status :: 5 - Production/Stable

## Test process
Yes, turned on the test.

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

#1650 - Turned on the test. Also updated some references to the old primaite paths. Finally, pushed the deployment status classifier to Development Status :: 5 - Production/Stable

Related work items: #1650
2023-07-24 11:11:17 +00:00
Chris McCarthy
31955c0c84 #1650 - Removed the commented out pytest.mark.skip statement now that the test does work. 2023-07-24 09:20:36 +01:00
Chris McCarthy
ef6585a298 #1650 - Turned on the test. Also updated some references to the old primaite paths. Finally, pushed the deployment status classifier to Development Status :: 5 - Production/Stable 2023-07-21 16:49:17 +01:00
Christopher McCarthy
5d9dd7a2d9 Merged PR 135: #1648 - Updated file headers
## Summary
 - Updated file header from 'Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.' to '© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK'

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

#1648 - Updated file header from 'Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.' to '© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK'

Related work items: #1648
2023-07-21 14:08:09 +00:00
Chris McCarthy
a39541623a #1648 - Reverted benchmark sessions and episodes numbers 2023-07-21 15:07:21 +01:00
Chris McCarthy
63297ef0ed #1648 - Added header to benchmark files 2023-07-21 15:06:05 +01:00
Chris McCarthy
4527b38aa6 #1648 - Reverted the benchmark files 2023-07-21 15:01:51 +01:00
Chris McCarthy
050ca68907 #1648 - Updated file header from 'Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.' to '© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK' 2023-07-21 14:54:09 +01:00
Christopher McCarthy
cda0a28c03 Merged PR 134: PrimAITE app and user dirs are version specific
## Summary
- Added _PrimaitePaths class that manages all the primaite locations using PlatformDirs. This class now creates new primaite locations for each version of primaite.
- Rolled the _PrimaitePaths class out throughout the code base.
- Updated the docs to reference the new version paths.
- Updated the author from qinetiq to dstl
- Bumped version number to 2.0.0rc2
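
The version-specific locations described above can be sketched roughly as follows. This is only an illustration under stated assumptions: it swaps PlatformDirs for plain `pathlib`, and the class name `VersionedPaths`, the `root/primaite/<version>` layout, and the `sessions` and `config` subdirectories are hypothetical, not the actual `_PrimaitePaths` implementation.

```python
from pathlib import Path


class VersionedPaths:
    """Hypothetical stand-in for _PrimaitePaths; the layout is assumed."""

    def __init__(self, root: Path, version: str) -> None:
        self.version = version
        # Each version gets its own directory tree, so installs do not clash.
        self.root = root / "primaite" / version

    @property
    def sessions(self) -> Path:
        return self.root / "sessions"

    @property
    def config(self) -> Path:
        return self.root / "config"

    def mkdirs(self) -> None:
        """Create every managed directory if it does not exist yet."""
        for path in (self.sessions, self.config):
            path.mkdir(parents=True, exist_ok=True)
```

Two instances built from the same root but different version strings resolve to disjoint trees, which is the property the PR is after.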

## Test process
- Manual checks. Tough to test the install paths.

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

Related work items: #1647
2023-07-21 13:50:56 +00:00
Chris McCarthy
b8ca5f1dca Merge branch 'dev' into feature/1647_Append_version_number_to_the_primaite_root_dir 2023-07-21 14:01:45 +01:00
Chris McCarthy
196d8855c3 #1647 - Added _PrimaitePaths class that manages all the primaite locations using PlatformDirs. This class now creates new primaite locations for each version of primaite.
- Rolled the _PrimaitePaths class out throughout the code base.
- Updated the docs to reference the new version paths.
- Updated the author from qinetiq to dstl
- Bumped version number to 2.0.0rc2
2023-07-21 14:00:50 +01:00
Czar Echavez
e1a396981a Merged PR 130: #1595: load session double run
## Summary
- Fixed the bug where session gets run twice when loading a session via CLI
- Added a test for the CLI run - xskipped while the bugfix for load session acting odd is tbd
- Fixed a minor bug in PrimAITE session where session_path is overwritten

## Test process
Added a new test for CLI, but xskipped while a different bug is tbd

Ran it locally; it no longer runs another session after the loaded session.
```
(venv) PS D:\Projects\ARCD\PrimAITE\PrimAITE> primaite session --load [REDACTED for security]\primaite\sessions\2023-07-20\2023-07-20_15-01-11
2023-07-20 15:04:21,320: Using: AgentFramework.SB3, AgentIdentifier.PPO, ActionType.NODE, observation_space=NODE_LINK_TABLE, Training: 5 episodes @ 256 stepsEvaluation: 5 episodes @ 256 steps
2023-07-20 15:04:21,335: Environment configuration loaded
Environment configuration loaded
2023-07-20 15:04:21,775: Welcome to the Primary-level AI Training Environment (PrimAITE) (version: 2.0.0rc1)
2023-07-20 15:04:21,775: The output directory for this session is: C:\Users\czar.echavez\primaite\sessions\2023-07-20\2023-07-20_15-04-21
2023-07-20 15:04:21,779: Beginning learning for 10 episodes @ 256 time steps...
2023-07-20 15:04:22,379: Episode: 1, Average Reward: -0.0020839843750000003
2023-07-20 15:04:23,137: Episode: 2, Average Reward: -0.0021933593750000004
2023-07-20 15:04:23,831: Episode: 3, Average Reward: -0.0022617187500000003
2023-07-20 15:04:24,486: Episode: 4, Average Reward: -0.002373046874999999
2023-07-20 15:04:25,125: Episode: 5, Average Reward: -0.0018066406250000014
2023-07-20 15:04:25,791: Episode: 6, Average Reward: -0.0017597656250000013
2023-07-20 15:04:26,415: Episode: 7, Average Reward: -0.0018437500000000014
2023-07-20 15:04:27,053: Episode: 8, Average Reward: -0.0019101562500000015
2023-07-20 15:04:27,715: Episode: 9, Average Reward: -0.0016777343750000013
2023-07-20 15:04:28,359: Episode: 10, Average Reward: -0.0015976562500000012
2023-07-20 15:04:28,550: Finished learning
2023-07-20 15:04:30,851: Beginning deterministic evaluation for 5 episodes @ 256 time steps...
2023-07-20 15:04:31,243: Episode: 1, Average Reward: -0.0018515625000000014
2023-07-20 15:04:31,663: Episode: 2, Average Reward: -0.0018515625000000014
2023-07-20 15:04:32,112: Episode: 3, Average Reward: -0.0018515625000000014
2023-07-20 15:04:32,505: Episode: 4, Average Reward: -0.0018515625000000014
2023-07-20 15:04:32,904: Episode: 5, Average Reward: -0.0018515625000000014
2023-07-20 15:04:32,998: Finished evaluation

```

Also fixed the xskipped tests, since the double running seems to have caused the issue of rewards not matching.

Added a test that runs the PrimAITE in CLI

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

#1595:
- Fixed the...
2023-07-21 11:32:32 +00:00
Christopher McCarthy
10c8604159 Merged PR 132: #1594 - Managed to get the evaluation of rllib agents working. A test has bee...
## Summary
Managed to get the evaluation of rllib agents working. A test has been added to test_primaite_session.py that now tests the full RLlib agent from end-to-end. I've also updated the tests in here to check that the mean reward per episode plot is created for both too. This will need a bit of a re-design further down the line, but for now, it works. Added a custom exception for RLlib eval only error.

Is this a hack? Yes. Does it work? Yes. We'll make this better later.

## Test process
Both an SB3 and a Ray RLlib agent are now tested in the test_primaite_session.py module.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1594
2023-07-21 10:26:31 +00:00
Chris McCarthy
722fe97c84 #1594 - Added docstrings and fixed training type. Added a clean-up of the unpacked agent in eval dir. 2023-07-21 10:33:22 +01:00
Christopher McCarthy
e62bee3052 Merged PR 131: #1639 - Added CHANGELOG.md and backfilled it with v1.1.0 and v1.1.1 release n...
## Summary
Added changelog and backfilled with v1.1.0 and v1.1.1 release notes.

Check what's written against notes in: https://nscuk.sharepoint.com/:f:/r/sites/SSE32-ARCDTrainingEnvironments/Shared%20Documents/General/01%20PrimAITE/01.01%20Releases/Release%20Notes?csf=1&web=1&e=uwPsyM

## Test process
N/A

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

#1639 - Added CHANGELOG.md and backfilled it with v1.1.0 and v1.1.1 release notes.

Related work items: #1639
2023-07-21 09:28:28 +00:00
Czar Echavez
7999eb56a5 #1595: refactor variable name 2023-07-21 09:31:37 +01:00
Czar Echavez
21598fd792 #1595: possibly fixed the tests by fixing the bug 2023-07-21 09:17:38 +01:00
Chris McCarthy
df52236a7d #1594 - Managed to get the evaluation of rllib agents working. A test has been added to test_primaite_session.py that now tests the full RLlib agent from end-to-end. I've also updated the tests in here to check that the mean reward per episode plot is created for both too. This will need a bit of a re-design further down the line, but for now, it works. Added a custom exception for RLlib eval only error. 2023-07-20 19:58:48 +01:00
Chris McCarthy
470f52f35e #1639 - Reinstalled pre-commit hook 2023-07-20 18:45:02 +01:00
Chris McCarthy
5475155686 #1639 - Added CHANGELOG.md and backfilled it with v1.1.0 and v1.1.1 release notes. 2023-07-20 17:24:55 +01:00
Brian Kanyora
0a6078df65 Merged PR 129: feature/1641:Update config.rst
## Summary
Changed environment config to training config in config.rst as Chris requested.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

feature/1641:
Changed environment config to training config

Related work items: #1641
2023-07-20 15:50:13 +00:00
Czar Echavez
fabbde9641 #1595:
- Fixed the bug where session gets run twice when loading a session via CLI
- Added a test for the CLI run - xskipped while the bugfix for load session acting odd is tbd
- Fixed a minor bug in PrimAITE session where session_path is overwritten
2023-07-20 16:21:30 +01:00
Brian Kanyora
69e7b23d2c feature/1641:
Changed environment config to training config
2023-07-20 16:09:57 +01:00
Sunil Samra
ae8afbdcdc Merged PR 128: #1640 - Update Sphinx Docs for 2.0.0 Release
## Summary
Updated the docs as per @<Christopher McCarthy>'s changes: added the `--ldc` and `--tc` prefixes to the `primaite session` command line and explained the default `primaite session` command when run with no arguments.

## Test process
*NA*

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

#1640 - Added --ldc and --tc prefixes and added small note about primaite session default run command

Related work items: #1640
2023-07-20 14:31:05 +00:00
SunilSamra
6ce816f2e1 #1640 - Added --ldc and --tc prefixes and added small note about primaite session default run command 2023-07-20 14:45:55 +01:00
Brian Kanyora
f840e924e3 Merged PR 125: feature/1637:Updating-UML-Diagram
## Summary
Updated the UML diagram using puml to help render a better output of the diagram.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [na] I have written **tests** for any new functionality added with this PR
- [na] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

feature/1637:
Updating the UML diagram using puml instead of mmd.

Related work items: #1637
2023-07-20 13:34:26 +00:00
Christopher McCarthy
3bbc7b8615 Merged PR 126: PrimAITE Benchmarking
## Summary
 - Added full benchmarking script that included plots and a LaTeX report. Ran the v2.0.0rc1 benchmark. Tidied a few other things up.

The code is a bit scrappy. But it's not released code. I will endeavour to tidy it up at a later date.

## Test process
Manually ran the script. This is the final report -> [PrimAITE v2.0.0rc1 Learning Benchmark.pdf](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/126/attachments/PrimAITE%20v2.0.0rc1%20Learning%20Benchmark.pdf)

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1632
2023-07-20 12:58:54 +00:00
Marek Wolan
f6a9063484 Merged PR 127: Add license
## Summary
Add license file and reference it in the pyproject.toml

## Test process
Running `pip-licenses` shows:
```
 pre-commit                     2.20.0       MIT License
 primaite                       2.0.0rc1     MIT
 primaite                       2.0.0rc1     MIT
 prometheus-client              0.17.1       Apache Software License
 prompt-toolkit                 3.0.39       BSD License
```
## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1638
2023-07-20 11:16:05 +00:00
Chris McCarthy
994546046e Create README.md 2023-07-20 10:55:59 +01:00
Brian Kanyora
90620e0c64 feature/1637:
Fixed the relationship between PrimAITE and Transaction.
2023-07-20 10:54:42 +01:00
Chris McCarthy
3731b2ba13 #1632 - Fixed output directory clear bug. Added gputil to dev deps. 2023-07-20 10:28:19 +01:00
Marek Wolan
dd9613853b Apply suggestions from code review 2023-07-20 09:21:35 +00:00
Chris McCarthy
f22681d6b4 #1632 - Added Python version to the System Information section in the report 2023-07-20 10:16:29 +01:00
Chris McCarthy
e9fc9a0d1a #1632 - Increased rolling window from 5 to 25 2023-07-20 10:06:44 +01:00
Marek Wolan
1e24ce7b9a Add license 2023-07-20 10:03:05 +01:00
Chris McCarthy
afbe2e1400 #1632 - Added full benchmarking script that included plots and a LaTeX report. Ran the v2.0.0rc1 benchmark. Tidied a few other things up. 2023-07-20 08:48:18 +01:00
Brian Kanyora
f8959a65e9 feature/1637:
Updating the UML diagram using puml instead of mmd.
2023-07-19 17:03:10 +01:00
Chris McCarthy
ba6f8f054b Merge branch 'dev' into feature/1632_Add_benchmarking_scripts 2023-07-18 13:24:15 +01:00
Christopher McCarthy
52a7185583 Merged PR 124: #1635 - Updated the session outputs details in primaite_session.rst
## Summary
- Updated the session outputs details in primaite_session.rst
- Fixed Logger typehint bugs
- Fixed typing issues in access_control_list.py

## Test process
Build the docs

![image.png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/124/attachments/image.png)

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

#1635 - Updated the session outputs details in primaite_session.rst

Related work items: #1635
2023-07-18 10:44:17 +00:00
Chris McCarthy
d730a206d8 #1635 - Fixed typing issues in access_control_list.py 2023-07-18 11:38:28 +01:00
Chris McCarthy
713aa279ec Merge branch 'dev' into feature/1635_Update_Primaite_Session_page_in_Docs
# Conflicts:
#	src/primaite/setup/old_installation_clean_up.py
#	src/primaite/setup/reset_example_configs.py
2023-07-18 11:36:31 +01:00
Chris McCarthy
e070b247b1 #1635 - Updated the session outputs details in primaite_session.rst
- Fixed Logger typehint bugs
2023-07-18 11:34:41 +01:00
Sunil Samra
a10a1d9267 Merged PR 120: Change Functionality of ACL Rules
## Summary
### ACL List
First, I changed the ACL in `access_control_list.py` from a `dict` to a `list`, so it is now an ordered structure. This was done so I could implement positions inside the `ACL` and `ANY` action spaces.

As a result, some functions have changed, such as `add_rule`, `remove_rule`, `is_blocked` and `get_relevant_rules`.

The ACL list is now a fixed size, and on initialisation it is filled with `None` values. When a function calls `self.acl`, the `implicit rule` (if there is one) is added after the last `ACLRule` object in the list. The remainder of the list (if there is leftover space) is padded out with `None`.

As the agent adds rules, the `None` values are replaced by `ACLRule` objects. The agent cannot overwrite an existing `ACLRule` with another; it can only write over `None` values.
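
The behaviour described above can be sketched as a small, self-contained class. This is a minimal illustration, not the PrimAITE code: `Rule` and `FixedSizeACL` are hypothetical names, and it assumes the implicit rule occupies one of the fixed slots (and is left out of the view when every slot already holds an explicit rule), which the description does not state.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Rule:
    """Hypothetical stand-in for ACLRule; not the PrimAITE class."""

    permission: str  # "ALLOW" or "DENY"
    source: str = "ANY"
    dest: str = "ANY"


class FixedSizeACL:
    """Fixed-size, ordered rule list padded with None (illustrative only)."""

    def __init__(self, max_rules: int, implicit_permission: Optional[str] = None) -> None:
        self.max_rules = max_rules
        self.implicit_rule = Rule(implicit_permission) if implicit_permission else None
        self._rules: List[Optional[Rule]] = [None] * max_rules

    def add_rule(self, rule: Rule, position: int) -> bool:
        """Write a rule into an empty slot; refuse to overwrite an existing rule."""
        if not 0 <= position < self.max_rules:
            return False
        if self._rules[position] is not None:
            return False
        self._rules[position] = rule
        return True

    @property
    def acl(self) -> List[Optional[Rule]]:
        """Explicit rules, the implicit rule after the last one, then None padding."""
        view = list(self._rules)
        if self.implicit_rule is None:
            return view
        # Index of the last explicit rule (-1 when the list is still empty).
        last = max((i for i, r in enumerate(view) if r is not None), default=-1)
        if last + 1 < len(view):  # assumption: skip the implicit rule if the list is full
            view[last + 1] = self.implicit_rule
        return view
```

With `max_rules=4` and an implicit `DENY`, an empty list is viewed as the implicit rule followed by three `None` entries; after one rule is added at position 0, the implicit rule moves to index 1.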

### ACL Training Config Changes
Changes have been made to the `training_config_main.yaml`. There are 2 new items:

- `implicit_acl_rule:` - Implicit ACL firewall rule at end of list to be default action (ALLOW or DENY)
- `max_number_acl_rules:` - Total number of ACL rules allowed in the environment

In the `OBSERVATION_SPACE` area of the config, `ACCESS_CONTROL_LIST` can be selected.

Both items have default values, defined in the `TrainingConfig` dataclass, which are used if none are specified, so older configs still work.

### ACL and ANY Action Spaces
I changed the ACL action space from a length of 6 to 7 by including the `position` at which the agent wants to place the ACL rule.

`position` = index in `self.acl` with bounds [0 to ...]

As a result, total possible actions have gone up.

### ACL Observation Space
In `observations.py` I have made a new observation component: Access Control List.
It has the following mappings/meanings:

        [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
        [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
        [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
        [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
        [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
        [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)

I created a new 0 meaning, which means NA and represents the None objects in the ACLList.
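
One way to read the mappings above is as an encoder from a rule slot to six integers. The sketch below is an assumption-laden illustration, not the code in `observations.py`: the function names, the dict-based rule, and the choice to encode an empty (`None`) slot as all zeros are hypothetical.

```python
from typing import List, Optional, Sequence

PERMISSION_CODES = {"DENY": 1, "ALLOW": 2}  # 0 is reserved for NA


def encode_field(value: Optional[str], known: Sequence[str]) -> int:
    """Map a field to 0 = NA, 1 = any, 2+ = index into `known` offset by 2."""
    if value is None:
        return 0
    if value == "ANY":
        return 1
    return known.index(value) + 2


def encode_rule(
    rule: Optional[dict],
    position: int,
    nodes: Sequence[str],
    protocols: Sequence[str],
    ports: Sequence[str],
) -> List[int]:
    """Encode one ACL slot as [permission, source, dest, protocol, port, position]."""
    if rule is None:
        # Assumption: an empty slot is NA in every field.
        return [0, 0, 0, 0, 0, 0]
    return [
        PERMISSION_CODES[rule["permission"]],
        encode_field(rule["source"], nodes),
        encode_field(rule["dest"], nodes),
        encode_field(rule["protocol"], protocols),
        encode_field(rule["port"], ports),
        position + 1,  # 1 = first index, since 0 means NA
    ]
```

For example, an ALLOW rule from any source to the second known node, using the first known protocol and the second known port, stored at list index 0, encodes as `[2, 1, 3, 2, 3, 1]`.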

Also, 'flatten' is no longer one of the observation space components; flattening is now done in `observations.py` when there are multiple components.

## Test process
I have written tests in a new `TestAccessControlList` object in `test_observations.py`.

I ran a single test which was 1000 episodes, SB3/PPO, Config 5 and ACL Observation Space. I seemed to get some interesting results which may need investigating on Monday.

![Figure_1.png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/120/attachments/Figure_1.png)

## Checklist
- ...
2023-07-18 10:31:15 +00:00
Marek Wolan
9e3285350a Get tests working with new ACL changes 2023-07-18 11:16:39 +01:00
Marek Wolan
c5f612889e Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules 2023-07-18 10:55:31 +01:00
Marek Wolan
9c6ee73b9e Merged PR 122: Typehint everything
## Summary
Added typehints to functions/methods, and class attributes.

## Test process
I used flake8-annotations and mypy to verify completeness and correctness. Mypy did throw up a very large number of errors and many of them point to some potential problems in the codebase. To elaborate, there are some places where there has been confusion as to whether objects should be strings, integers, or enums. Resolving this is out of scope of this PR but I will create more tickets with concrete examples.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1623
2023-07-18 09:43:08 +00:00
Marek Wolan
a2ef4328dd Remove redundant 'if TYPE_CHECKING' statements 2023-07-18 10:21:06 +01:00
Marek Wolan
393505b98b Ensure everything is still typehinted 2023-07-18 10:13:54 +01:00
Chris McCarthy
e198c17ac0 #1632 - Added benchmarking script 2023-07-18 10:11:01 +01:00
Marek Wolan
a7a5fb8598 Mark failing tests as Xfail to force build success 2023-07-18 10:08:02 +01:00
Marek Wolan
3d0e50823a Merge branch 'dev' into feature/1623-typehints 2023-07-18 10:03:48 +01:00
Christopher McCarthy
15f37c938f Merged PR 123: #1631 - Added the DEFCON 703 header to all possible files
## Summary
Added the DEFCON 703 header to all possible files

## Test process
Built docs to confirm that the top-of-the-page comment does not break anything

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

#1631 - Added the DEFCON 703 header to all possible files

Related work items: #1631
2023-07-17 19:51:29 +00:00
SunilSamra
9520cfea24 #901 - Replaced "ALLOW" with RulePermissionType.ALLOW
- Added Explicit ALLOW to test_configs in order for tests to work
- Added typing to access_control_list.py and acl_rule.py
2023-07-17 20:40:00 +01:00
Chris McCarthy
81295a4fc4 #1631 - Updated the copyright statement to comply with DEFCON 703 Edition 08/13 2023-07-17 19:57:34 +01:00
SunilSamra
a2f43b5abc #901 - merged dev into branch 2023-07-17 19:54:07 +01:00
SunilSamra
95b6211781 Merge remote-tracking branch 'origin/feature/901-change-functionality-acl-rules' into feature/901-change-functionality-acl-rules 2023-07-17 19:42:29 +01:00
SunilSamra
3aab6a3738 #901 - ran black pre-commit over observations.py to fix it 2023-07-17 19:42:05 +01:00
Sunil Samra
1721f2eb84 Apply suggestions from code review 2023-07-17 18:36:13 +00:00
Chris McCarthy
2d1a1e6db7 #1631 - Added the DEFCON 703 header to all possible files 2023-07-17 19:28:43 +01:00
Christopher McCarthy
35af1e9d1e Merged PR 121: #1629 - Added rllib test
## Summary
Quick test that uses RLLIB in a session

## Test process
The learning session completes, then we check the number of rows in both the average reward per episode and all transactions CSV files.
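
A row-count check of this kind could look roughly like the helper below. It is a sketch only: `count_data_rows` is a hypothetical name, and it assumes each output file has a single header row, which the description does not confirm.

```python
import csv
from pathlib import Path


def count_data_rows(csv_path: Path) -> int:
    """Return the number of rows after the header (assumes one header row)."""
    with csv_path.open(newline="") as handle:
        reader = csv.reader(handle)
        next(reader, None)  # skip the header if the file is not empty
        return sum(1 for _ in reader)
```

A test would then compare the result against the expected number of episodes for the average reward file, and against the expected number of steps for the transactions file.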

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

#1629 - Added rllib test

Related work items: #1629
2023-07-17 17:28:51 +00:00
Czar Echavez
e1ac628793 Merged PR 119: Loading SB3 Agents + Loading agent via PrimaiteSession
## Summary
- Added a feature which allows a user to load a previous SB3 session
- Added a feature which allows a user to load a previous PrimaiteSession
- Added a feature which allows a user to load a previous session via the CLI: `primaite session --load "<SESSION_PATH>"`
- RLlib is TODO in another ticket #1626
- Parallel tests via the [pytest-xdist](https://pypi.org/project/pytest-xdist/) dependency (MIT licensed)
- Moved hardcoded agent into hardcoded_abc.py
- renamed agent.py to agent_abc.py to clarify it is an abstract base class
- Added documentation to clarify how to use the feature via CLI or using the run function via main.py

## Test process
Created [test_session_loading.py](https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/119?_a=files&path=/tests/test_session_loading.py) which loads a previously run session and then performs a learn and evaluation run on the loaded agent/PrimAITE session.

The test copies the saved session into a temporary folder, which is then set as the test session path. Once the test is done, the temporary folder should then be deleted.
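
The copy-then-clean-up pattern described here can be sketched with the standard library alone. This is not the code in test_session_loading.py; `temp_session_copy` is a hypothetical helper name, and the real test may manage its temporary folder differently.

```python
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator


@contextmanager
def temp_session_copy(saved_session: Path) -> Iterator[Path]:
    """Yield a throwaway copy of a saved session directory, deleted on exit."""
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / saved_session.name
        # Copy the whole session tree so the original is never modified.
        shutil.copytree(saved_session, target)
        yield target
```

Inside the `with` block the yielded path can be used as the session path for the load; the original saved session is left untouched and the copy is removed when the block exits.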

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1595
2023-07-17 15:26:42 +00:00
Marek Wolan
bfce2f9a7b Change typehints after mypy analysis 2023-07-17 16:22:07 +01:00
SunilSamra
257be9532f #901 - Changed num_eval_steps back to 1 in ppo_seeded_training_config.yaml 2023-07-17 15:54:15 +01:00
SunilSamra
57157db08c Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules 2023-07-17 15:51:01 +01:00
Czar Echavez
39b30460cd Merge branch 'dev' into feature/1595-finalise-and-test-the-loading-of-trained-agents 2023-07-17 15:23:46 +01:00
Sunil Samra
dd21f9440f Apply suggestions from code review 2023-07-17 14:21:37 +00:00
Brian Kanyora
a432822bcb Merged PR 113: feature/1597-Getting-Started
## Summary
Add a Getting started page to the docs file.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [na] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1597
2023-07-17 14:19:16 +00:00
SunilSamra
ec938ce761 #901 - Changed num_eval_steps back to 1 in ppo_seeded_training_config.yaml 2023-07-17 14:06:33 +01:00
SunilSamra
007a0c4b98 Merge remote-tracking branch 'origin/feature/901-change-functionality-acl-rules' into feature/901-change-functionality-acl-rules 2023-07-17 14:01:32 +01:00
SunilSamra
2526427f2f #901 - Fixed bug in implicit rule - comparing it to string ALLOW or DENY in access_control_list.py 2023-07-17 13:58:06 +01:00
Sunil Samra
cc09fe9079 Removed apply_implicit_rule comment 2023-07-17 12:45:31 +00:00
SunilSamra
78d7f39342 #901 - Removed flatten from training configs
- Added flatten operation in observations.py when there are multiple obs components
- Updated config.rst docs
2023-07-17 13:44:16 +01:00
SunilSamra
da20c0e9e6 #901
- Removed bool apply_implicit_rule
- Set default implicit_rule to EXPLICIT DENY
- Added position to ACLs in laydown configs
- Removed apply_implicit_rule from training configs
2023-07-17 13:00:58 +01:00
Chris McCarthy
360eb38c2b #1629 - Added assertion in the test that checks the length of the all transactions file too.
- Added supporting function on the TempPrimaiteSession class that reads the all transactions csv file.
- Some renaming of the functions.
2023-07-17 12:14:47 +01:00
Czar Echavez
6b76214eb2 #1595: set default tc and ldc to None for AgentABC and PrimaiteSession + adding a comment for cli load flag 2023-07-17 11:54:54 +01:00
Chris McCarthy
75c91b9eb9 #1629 - Added rllib test 2023-07-17 11:50:07 +01:00
Marek Wolan
432da5ca90 Add typehint for agent config class 2023-07-17 11:21:29 +01:00
Brian Kanyora
9a0b14b111 Merge remote-tracking branch 'origin/dev' into feature/1597-Getting-Started 2023-07-17 10:33:59 +01:00
SunilSamra
707d8f6189 #901
- Added check in access_control_list.py which sets implicit permission to NA if boolean is False
- Changed the defaults in training_config.py so that each scenario has an EXPLICIT ALLOW rule as default implicit rule
- Updated test_seeding_and_deterministic_session.py because change no. 2 adds an extra rule to that scenario
2023-07-17 10:27:56 +01:00
Brian Kanyora
50697c6f75 Apply suggestions from code review 2023-07-17 09:23:11 +00:00
SunilSamra
9df8d132fc #901 - added to config.rst and added new ACL main config options 2023-07-17 10:08:12 +01:00
Marek Wolan
2bb71623fa Fix types according to mypy 2023-07-14 16:38:55 +01:00
Czar Echavez
ea7c1519fe #1595: minor fix to cli command 2023-07-14 16:04:34 +01:00
SunilSamra
8aa71c3ff8 #901 - amended comment in observations.py 2023-07-14 16:04:13 +01:00
Czar Echavez
7c2ff55da2 #1595: added loading sessions to run command + test + documentation for how to use loading sessions 2023-07-14 15:51:38 +01:00
SunilSamra
1b6244d13f #901 - amended comment in training_config_main.yaml 2023-07-14 15:49:18 +01:00
SunilSamra
661c865108 #901 -
- Added comments in access_control_list.py
- Changed obs_shape to max_number_acl_rules from max_number_acl_rules + 1 as index starts from 1
- Commented episode and step print line from test_single_action_space.py
2023-07-14 15:27:37 +01:00
SunilSamra
eb75d15722 901 - Added another test and tidied up comments in test_observation_space.py and tidied up comments in observations.py 2023-07-14 14:51:26 +01:00
Marek Wolan
31fedb945e Add typehints 2023-07-14 14:43:47 +01:00
SunilSamra
4e53564670 901 - Changed the default expected_mean_reward_per_episode values in test_seeding_and_deterministic_session.py 2023-07-14 14:26:10 +01:00
Czar.Echavez
8e2f105d57 #1595:
- Added ability to load sessions via PrimaiteSession
- PrimaiteSession loading test
- Added a NotImplemented RLlib loading for now
- Added the ability to load sessions for hardcoded agents
- Moved Session metadata parsing to utils
2023-07-14 14:14:03 +01:00
Chris McCarthy
f9c7cafe87 #901 - Dropped temp_primaite_sessiion_2 from conftest.py.
- Re-added the hard-coded mean rewards per episode values from a pre-trained agent to the deterministic test in test_seeding_and_deterministic_session.py
- Partially tidied up some tests in test_observation_space.py; still some work to be done on this at a later date.
2023-07-14 14:13:11 +01:00
SunilSamra
e743b2380c 901 - fixed test_observation_space.py, added test fixture for test_seeding_and_deterministic_session.py and increased default max number of acls 2023-07-14 12:29:50 +01:00
Marek Wolan
c2931bde6c Added type hints 2023-07-14 12:01:38 +01:00
Czar Echavez
436448beed #1595: fix poorly merged tests + files 2023-07-14 11:21:59 +01:00
Czar Echavez
7b929109dc #1595: test to make sure that the loaded agent trains + remove unnecessary files + fixing agent save output name 2023-07-14 10:56:28 +01:00
Czar Echavez
118b05ede0 Merge branch 'dev' into feature/1595-finalise-and-test-the-loading-of-trained-agents 2023-07-14 08:39:52 +01:00
Marek Wolan
9650669c83 Add More Typehint 2023-07-13 18:08:44 +01:00
SunilSamra
558223e8b6 901 - removed print statements and merged with dev 2023-07-13 17:14:59 +01:00
SunilSamra
77f717c649 Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules 2023-07-13 16:48:02 +01:00
Czar.Echavez
738e5b5dca #1595 missed hardcoded_abc file in commit 2023-07-13 16:24:30 +01:00
Czar.Echavez
606354614a #1595:
- SB3 Agent loading
- rename agent.py -> agent_abc.py
- rename hardcoded.py -> hardcoded_abc.py
- Tests
- Added in test asset that is used to load the SB3 Agent
2023-07-13 16:24:03 +01:00
Marek Wolan
36e48dc8e9 Continue Adding Typehints 2023-07-13 12:25:54 +01:00
SunilSamra
0ab4dab72a 901 - fixed test_single_action_space.py test 2023-07-13 11:45:23 +01:00
SunilSamra
f8cb18c654 901 - changed acl current obs from list to numpy.array, changed default ACL list in training_config.py to FALSE, and tried to make test_seeding_and_deterministic_session.py test without fixed reward results 2023-07-13 11:04:11 +01:00
Brian Kanyora
fd2ab39edf feature/1597:
Added dependencies to the index.rst since v1.1.0.
2023-07-13 09:36:04 +01:00
Brian Kanyora
f5e1ef7491 feature/1597:
Added dependencies to the index.rst since v1.1.0.
2023-07-13 09:35:42 +01:00
Marek Wolan
f4a70394e0 Type hint ACLs 2023-07-12 16:58:12 +01:00
Marek Wolan
c61770825a Merged PR 116: Update documentation
## Summary
* Update observation space information and standardise formatting of code blocks in that section
* Remove non-ascii quotation characters
* Update custom blue agent page to match new AgentSession classes.
* Introduce glossary
* Provide a first draft of migration guide for 1.2 to 2.0 (probably not comprehensive)

## Test process
Sphinx is able to build the documentation as checked on my local machine

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1602
2023-07-12 14:28:41 +00:00
Sunil Samra
a19fbd1e98 Merged PR 115: Configure Different Episode and Step Counts for Training and Evaluation
## Summary
Training configs now have 2 different types of episode and step counts - one for train and one for evaluation.

`num_train_episodes`
`num_train_steps`
`num_eval_episodes`
`num_eval_steps`

## Test process
A test file `test_train_eval_episode_steps.py` has been implemented which runs train and evaluation session on two particular configs.

The train and evaluation sessions have different episode and step counts, and the test checks that the output log files have the correct number of `total_steps` and `total_episodes`.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1566, #1589
2023-07-12 13:34:58 +00:00
Czar Echavez
85c360548b #1595: run tests in parallel 2023-07-12 12:04:26 +01:00
SunilSamra
06c20f6984 Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules
# Conflicts:
#	src/primaite/acl/access_control_list.py
2023-07-12 10:45:03 +01:00
SunilSamra
96b48aad79 1566 - removed redundant config file 2023-07-12 09:52:54 +01:00
SunilSamra
f817efdc69 901 - fixed how acls are added into list with new logic - agent cannot overwrite another acl in the list 2023-07-12 09:47:16 +01:00
Marek Wolan
c7547f715e Add better hyperlinks 2023-07-12 09:16:40 +01:00
SunilSamra
d4469f5226 Merge remote-tracking branch 'origin/dev' into feature/1566-configure_episode-steps-learn-eval 2023-07-11 17:18:05 +01:00
Chris McCarthy
3c20764096 #1597 - Fixed Project Links side bar 2023-07-11 15:50:37 +01:00
Chris McCarthy
11defda955 Merge remote-tracking branch 'origin/feature/1597-Getting-Started' into feature/1597-Getting-Started
# Conflicts:
#	docs/source/getting_started.rst
2023-07-11 15:47:47 +01:00
Chris McCarthy
5b3663c3cf #1597 - Added code tabs to getting started page 2023-07-11 15:47:13 +01:00
Czar.Echavez
baa14b6cd7 #1595: Moved hardcoded agent into its own file 2023-07-11 15:03:02 +01:00
Brian Kanyora
79724d6884 Added a space 2023-07-11 14:15:02 +01:00
Brian Kanyora
30d8478a78 Addressing Sunils comments from my 2023-07-11 13:49:01 +01:00
Marek Wolan
0ec2f79ac3 Merge remote-tracking branch 'origin/dev' into feature/1602-update-docs 2023-07-11 13:13:02 +01:00
Marek Wolan
0c63d197e5 Merged PR 114: Change build pipeline to only run once on commits
## Summary
Unfortunately, I had to do away with the nice and neat matrix strategy for builds, because they do not support conditionals. Instead, I manually replicated the behaviour of the matrix but added a conditional to run every platform only when the 'build reason' is PR.

## Test process
*How have you tested this (if applicable)?*

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1603
2023-07-11 12:12:32 +00:00
SunilSamra
585d35338f 1566 - updated docs for new items in training_config 2023-07-11 12:40:25 +01:00
SunilSamra
f3750032be 1566 - applied pre-commit 2023-07-11 12:37:14 +01:00
SunilSamra
350b3db3f6 901 - changed implicit_acl_rule from str to enum name 2023-07-11 12:36:22 +01:00
Marek Wolan
6a888d2efe Updated migration guide 2023-07-11 12:10:20 +01:00
Marek Wolan
5f6bc32b98 Added draft migration guide. 2023-07-11 12:01:48 +01:00
SunilSamra
6b59ce960d Merge remote-tracking branch 'origin/dev' into feature/1566-configure_episode-steps-learn-eval
# Conflicts:
#	src/primaite/config/training_config.py
2023-07-11 11:39:21 +01:00
Marek Wolan
9e936513d5 Improved order of glossary terms 2023-07-11 11:31:29 +01:00
Marek Wolan
dc26863216 Completed glossary 2023-07-11 11:13:28 +01:00
Marek Wolan
56fd9c4d0a Merge remote-tracking branch 'origin/dev' into feature/1602-update-docs 2023-07-11 10:12:40 +01:00
Marek Wolan
1633900ce7 Fix typo in Build.Reason 2023-07-11 09:01:43 +00:00
Marek Wolan
6c7ec62166 Fixed formatting with pre-commit 2023-07-11 09:57:27 +01:00
Marek Wolan
a07ce00852 Added glossary 2023-07-11 09:56:52 +01:00
Marek Wolan
dcf5bfddfa Fix syntax 2023-07-11 08:54:22 +00:00
Marek Wolan
a303e9096a Changed structure of build pipeline yaml 2023-07-11 08:53:37 +00:00
Marek Wolan
81a8058836 Change parameter matrix to list instead of dict 2023-07-11 08:22:30 +00:00
Marek Wolan
c641f67914 Capitalisation error in value 2023-07-11 08:15:16 +00:00
Marek Wolan
7f64d06ad4 Fix indent 2023-07-11 08:14:34 +00:00
Marek Wolan
c8191e60ba Typo in word only 2023-07-11 08:14:08 +00:00
Marek Wolan
d555584e90 Potentially fix syntax error 2023-07-11 08:08:29 +00:00
Marek Wolan
548ecf8e08 Edit pipeline to use runtime parameters
https://stackoverflow.com/a/70046417
2023-07-11 08:05:38 +00:00
Marek Wolan
d8cfbc1042 Updated azure-ci-build-pipeline.yaml 2023-07-11 07:19:58 +00:00
Marek Wolan
831469d01c Built matrix conditionally 2023-07-11 07:16:11 +00:00
Marek Wolan
19a9cef130 Merged PR 105: Fix errors while trying to run Hardcoded agent
## Summary
Since we added File System State as a new part of the observation space, some of the assumptions made by imported ADSP code were not met. This is addressed by these changes.

The code no longer crashes, but the hardcoded ACL agent doesn't work very well: it keeps returning action 0 and receives a low reward. Also, if there are ACL rules with 'ANY' as a source IP, it crashes the function `get_node_of_ip` within the HardCodedACLAgent._calculate_action_full_view() method.

I'm not sure how much effort we need to spend fixing the hardcoded agents as they don't seem like they were delivered in a finished state.

## Test process
Can confirm the hardcoded agent can run within a primaite session now.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

## note
I would appreciate some input about what we should do with hardcoded agents for release 2.0.0, it may require significant effort to get them working correctly.

Related work items: #1587
2023-07-10 15:10:12 +00:00
Brian Kanyora
5939fda2ba Merge remote-tracking branch 'origin/dev' into feature/1597-Getting-Started
# Conflicts:
#	docs/source/primaite-dependencies.rst
2023-07-10 15:48:32 +01:00
Brian Kanyora
ecc06a5db0 feature/1597:
pre commit fix
2023-07-10 15:41:50 +01:00
Marek Wolan
30bcdba429 Finished writing custom agent example. 2023-07-10 14:56:06 +01:00
SunilSamra
563ff72fd6 1566 - fixed the test_training_config.py test file by removing num_steps from init 2023-07-10 13:24:34 +01:00
Marek Wolan
ca737e080f Changed build pipeline experimentally. 2023-07-10 10:25:26 +00:00
SunilSamra
921dc934c2 1566 - added correct num_train_episodes etc values to configs, fixed test_reward.py 2023-07-10 11:25:26 +01:00
Marek Wolan
5ec8d3c8c1 Merged PR 110: Update Observation spaces description
## Summary
This minor update adds more detail and links to relevant pages within the API docs.

## Test process
Locally built docs in HTML format to verify all content displays correctly.

Related work items: #1596
2023-07-10 10:20:42 +00:00
Marek Wolan
43a4f93626 Changed order of text in custom agent docs 2023-07-10 11:19:47 +01:00
Brian Kanyora
3c9b8a272a Additional syntax changes 2023-07-10 10:28:27 +01:00
Brian Kanyora
e3ad1470df feature/1597:
Small syntax changes
2023-07-10 10:01:25 +01:00
Marek Wolan
bd6f9fc309 Merge remote-tracking branch 'origin/dev' into bugfix/1587-hardcoded-agent 2023-07-10 09:15:25 +01:00
Marek Wolan
47d7e9f3f6 Merged PR 104: Fix formatting in docstrings
## Summary
Fixes some incorrectly formatted documentation, such as in the observation module. Also adds some missing module-level docstrings and a PrimAITE Favicon to the docs.

Removed Primaite-dependencies.rst as it's autogenerated.

## Test process
Purely cosmetic, so functionality not tested. I did render the HTML output to observe that some mistakes have been fixed.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [na] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

Related work items: #1572
2023-07-10 08:14:08 +00:00
Marek Wolan
0145532103 Update docs 2023-07-09 20:23:53 +01:00
Marek Wolan
91287f8666 Merge remote-tracking branch 'origin/dev' into feature/1572-fix-docs-formatting 2023-07-09 18:13:57 +01:00
Marek Wolan
605a5b4cd6 Merge remote-tracking branch 'origin/dev' into bugfix/1587-hardcoded-agent 2023-07-09 18:07:30 +01:00
Marek Wolan
17894376c6 Removed comment 2023-07-09 18:07:21 +01:00
Marek Wolan
9d49406df6 Merge remote-tracking branch 'origin/dev' into feature/1596-better-observation-docs 2023-07-09 18:05:13 +01:00
Brian Kanyora
23adc740cd Resolved more syntax errors 2023-07-07 16:32:35 +01:00
SunilSamra
41fab6562e 1566 - updated configs to correct values of step count and number of episodes 2023-07-07 16:26:12 +01:00
Brian Kanyora
752a611b89 Fixed the rst syntax 2023-07-07 16:25:55 +01:00
Marek Wolan
677d12b550 Merged PR 106: Resolve TODOs about documenting functions
## Summary
- Added type hints and docstrings to functions imported from ADSP.
- Imported `get_relevant_rules` which was referenced but didn't exist.
- Removed duplicated function definitions in `agents.utils`

## Test process
The changes in this PR are almost exclusively cosmetic. I can confirm that after adding/removing functions, the unit tests passed fine. I was also able to run the Hardcoded node and ACL agents without problems.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [na] I have written **tests** for any new functionality added with this PR
- [na] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1575
2023-07-07 15:10:44 +00:00
Chris McCarthy
40381833d3 #1566 - Refactored the test_train_eval_episode_steps.py to use TempPrimaiteSession.
- Fixed all errors that were caused by fixing the above.
- Some tests still fail, these are for SS to fix.
- Dropped the old run_generic stuff from conftest.py
2023-07-07 15:50:14 +01:00
SunilSamra
35b481a2f3 Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules 2023-07-07 15:14:05 +01:00
Chris McCarthy
d49f73f139 Merge remote-tracking branch 'origin/dev' into 1566-configure-episode-steps-learn-eval
# Conflicts:
#	src/primaite/agents/rllib.py
2023-07-07 14:34:20 +01:00
Marek Wolan
d7bf678b1f Reworded observation description 2023-07-07 14:24:37 +01:00
SunilSamra
e03c29b921 1566 - added test file and edited configs to include types of num steps and modified agents to use correct step and episode counts 2023-07-07 14:13:47 +01:00
Marek Wolan
bbb305d561 Update observation space documentation 2023-07-07 13:52:14 +01:00
Brian Kanyora
4ef7831bfa Added a getting started file 2023-07-07 11:37:57 +01:00
Marek Wolan
7e0eee5d73 Merge remote-tracking branch 'origin/dev' into feature/1572-fix-docs-formatting 2023-07-07 10:30:11 +01:00
Marek Wolan
f4b98542b6 Standardise docstring summary line placement. 2023-07-07 10:28:00 +01:00
Czar Echavez
036e0fe342 Merged PR 89: #1386 Enable a repeatable/deterministic baseline test
## Summary
- Added the fix from #1535 with minor changes to make sure that the `primaite_env.step()` function can properly parse the action
- added the config deterministic and seed to training config
- added the deterministic and seed to the Training config class, with defaults `False` and `None` respectively
- minor fix to `primaite_env.close()` function so that it now works

## Test process
Added e2e tests for generic, ppo and a2c which evaluate a trained agent twice to make sure that the seeding and deterministic action work

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

#1386: added the ability to set deterministic and seeding RNG when training and evaluating + the fix provided in #1535

Related work items: #1386, #1535
2023-07-07 09:22:47 +00:00
Czar Echavez
04e52453b1 Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test 2023-07-06 22:22:37 +01:00
Christopher McCarthy
207601b81f Merged PR 109: Auto save agent at end of training
## Summary
* Made RLlib and SB3 agents save at the end of each learning session by default using a common file naming format. Also now agents only checkpoint every n and not on the final episode.

## Test process
* Tests the saved agent file in the test_primaite_session test.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1593
2023-07-06 16:29:48 +00:00
Marek Wolan
3a75ed8ccc Merged PR 108: Divide default rewards by 10000
## Summary
As per the discussion this morning, this PR reimplements changes that were made by ADSP to make the default rewards smaller. This also type hints rewards as floats.

## Test process
I checked that sessions are able to run and that they report values similar to what we are used to but smaller by a factor of 10000. I did not change the reward values in the integration test configs, and the tests still pass.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #889, #1586
2023-07-06 15:17:47 +00:00
Marek Wolan
86725064ec Added docstrings to class initialisers 2023-07-06 16:08:51 +01:00
Marek Wolan
2a08d3a2a5 Removed reference to file that no longer exists 2023-07-06 15:18:49 +01:00
Marek Wolan
82a5122276 Add __init__ to class special members doc 2023-07-06 15:18:33 +01:00
Marek Wolan
4c03aaee24 undeleted api (lol) 2023-07-06 15:05:39 +01:00
Marek Wolan
1ade92f55c Deleted icon 2023-07-06 15:04:46 +01:00
Chris McCarthy
c9f4741655 #1593 - Ran pre-commit hook 2023-07-06 14:18:49 +01:00
Chris McCarthy
82d7c168fe #1593 - Check that agent saved file exists 2023-07-06 14:13:02 +01:00
Chris McCarthy
159d47fd6c #1963 - Made RLlib and SB3 agents save at the end of each learning session by default using a common file naming format. Also now agents only checkpoint every n and not on the final episode 2023-07-06 13:56:12 +01:00
Czar Echavez
46b44f9e23 #1386: remove redundant config files + test fixtures + fixing deterministic and seed config description in documentation to avoid misunderstandings 2023-07-06 13:27:44 +01:00
Marek Wolan
3b91a99070 Updated rewards type description in docs 2023-07-06 12:56:24 +01:00
Marek Wolan
c5d7d55747 Change reward to float and divide by 10000 2023-07-06 12:52:14 +01:00
Czar Echavez
99f1f7cfc1 #1386: remove setting of global seed + running pre-commit checks 2023-07-06 12:10:26 +01:00
Chris McCarthy
3438ce7e09 #1386 - Updated tests in test_seeding_and_deterministic_session.py to use TempPrimaiteSession.
- Added test_seeded_learning test and test_deterministic_evaluation test.
- Passed config values seed and deterministic to ppo agent
- Dropped deterministic override in evaluate functions
- TempPrimaiteSession now writes files to a UUID folder rather than datetime
- Added seed to Ray RLlib agent setup in rllib.py
- Added seed to SB3 agent setup in sb3.py
2023-07-06 11:35:44 +01:00
SunilSamra
4371ca13fc 1566 - added train_episodes, train_steps, eval_episodes and eval_steps to training_config_main.yaml 2023-07-06 11:12:51 +01:00
SunilSamra
f651937759 901 - changed how acl rules are added to access control list and added structure to AccessControlList observation 2023-07-06 11:07:21 +01:00
Marek Wolan
e174db5d9e Rescaled default rewards by a factor of 1/10000 2023-07-06 10:51:34 +01:00
Marek Wolan
87bdaa1ec3 Updated documentation 2023-07-06 10:34:27 +01:00
Marek Wolan
c38dda34b9 Removed duplicated function definitions 2023-07-06 10:23:14 +01:00
Chris McCarthy
8faf9d70a0 temp 2023-07-06 10:07:54 +01:00
Marek Wolan
b426d5802e Updated docstrings 2023-07-05 16:46:23 +01:00
Marek Wolan
5c167293e3 Add docstrings and type hints. 2023-07-05 16:19:43 +01:00
Marek Wolan
0ae7158859 Merge branch 'bugfix/1587-hardcoded-agent' into feature/1575-docstring-param-desc 2023-07-05 15:22:13 +01:00
Czar Echavez
713225b432 #1386: remove unneeded configs + setting the seed globally + temp test 2023-07-05 15:02:41 +01:00
Marek Wolan
7482aead76 typo 2023-07-05 14:50:03 +01:00
Marek Wolan
f62b2aef1c Fix minor typos in docstrings 2023-07-05 14:13:43 +01:00
Marek Wolan
171b5cb58e Imported ADSP function for ACL 2023-07-05 14:10:52 +01:00
Marek Wolan
b3d4eb4ec0 Changed hardcoded agent helper for new obs space 2023-07-05 13:58:46 +01:00
Czar Echavez
075b11aeca #1386: fix saving of agent 2023-07-05 11:41:18 +01:00
SunilSamra
f121b0e21c 901 - merged with dev 2023-07-05 11:34:15 +01:00
Marek Wolan
940f37bfc6 Merge branch 'feature/1572-fix-docs-formatting' of https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE into feature/1572-fix-docs-formatting 2023-07-05 10:14:20 +01:00
Marek Wolan
38a3666e8e Move class docstrings out of init function. 2023-07-05 10:14:16 +01:00
Marek Wolan
ea01e2209b Updated access_control_list.py 2023-07-05 09:00:41 +00:00
Marek Wolan
3ced1a1913 Update some param descriptions for hardcoded agent 2023-07-05 09:54:50 +01:00
Marek Wolan
cda9819e72 Add blank lines at the end of file. 2023-07-05 09:22:49 +01:00
Marek Wolan
eac79e0941 Add missing module level docstrings. 2023-07-05 09:19:58 +01:00
SunilSamra
3f440c0a28 901 - updated observations.py to change and add new mapping of ACL rules to represent no rule present in list 2023-07-05 09:08:03 +01:00
Czar Echavez
9001510fe7 #1386: fix bug with agent zip file not being saved after run 2023-07-04 16:30:31 +01:00
Marek Wolan
0756e61e5d add module level docstrings 2023-07-04 13:11:06 +01:00
Marek Wolan
7bdcee5c46 remove primaite dependencies as it's autogenerated 2023-07-04 11:57:10 +01:00
Marek Wolan
d41e2ad590 Resolve remaining build warnings for docs 2023-07-04 11:34:36 +01:00
Marek Wolan
5e270c7673 Format docstrings 2023-07-04 11:11:52 +01:00
Marek Wolan
3de6208915 fix formatting on Observation docs 2023-07-04 10:57:00 +01:00
Marek Wolan
3abe39aa10 Add Favicon 2023-07-04 10:55:07 +01:00
Czar Echavez
410afc1d40 Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test 2023-07-04 09:41:07 +01:00
Christopher McCarthy
e199dc52c0 Merged PR 101: Integrate ADSP RLlib and use PrimaiteSession for running between agent frameworks
## Summary
* Brought over the RLlib, hardcoded agents, and simple agents from ADSP 1.1.0. This opened a can of worms... ADSP got their stuff working in notebooks (**_stares at data scientists!_** 😂) but hadn't integrated it into the PrimAITE package or made the other PrimAITE functionality work with it.
* RLlib agents have been fully integrated with the wider PrimAITE package. This was done by:
  * The creation of the `AgentSessionABC` and `HardCodedAgentSessionABC` classes.
  * `SB3Agent` and `RLlibAgent` classes then inherited from `AgentSessionABC`.
  * The ADSP hardcoded agents were integrated into subclasses of `HardCodedAgentSessionABC`.
  * The random and dummy agents were also integrated into subclasses of `HardCodedAgentSessionABC`.
  * A set of session output directories were created and managed by the agent session to enable consistent storage of session outputs in a common format regardless of the agent type.
  * The main config was refactored so that it had
    * **agent_framework** - To identify whether SB3, RLlib, or Custom.
    * **agent_identifier** - To identify whether PPO, A2C, hardcoded, random, or dummy.
    * **deep_learning_framework** - To identify which framework to use for RLlib.
* Transactions have been overhauled to simplify the process. It also means that they're written in real time so they're not lost if the agent crashes.
* Tests completely overhauled to use `PrimaiteSession`, or at least a test subclass, `TempPrimaiteSession`. It's temp because it uses temp directory rather than main primaite session directory, and it cleans up after itself.
* All the crap removed from `main.py` and made it so that it just runs `PrimaiteSession`.

Now this is where I went off on a tangent...
* CLI added to just make my life and everyone else's life easier.
* Primaite app config added to hold things like logging format, levels etc.
* A `primaite.data_viz.session_plots` module added so that the average reward per episode for each session is plotted and saved for each session (this helped while we were testing and bug fixing).

## Test process
* All tests use `TempPrimaiteSession`, which uses `PrimaiteSession`.
* I still need to write tests that run the RLlib, hardcoded, and random/dummy agents. I'll do that now while this is being reviewed.

## Still to do
* Update docs. I'm getting this PR up now so we can get it in to make use of the features. I'll get the docs updated today either on this branch or another branch (depending on how long this review takes).

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #917, #1563
2023-07-04 08:08:31 +00:00
Chris McCarthy
34b294f89a #917 - Reinstalled the pre-commit hook 2023-07-03 20:40:38 +01:00
Chris McCarthy
410d5abe12 #917 - Synced with dev and integrated the new observation space 2023-07-03 20:36:21 +01:00
Chris McCarthy
820f436f8e Merge remote-tracking branch 'origin/dev' into feature/917_Integrate_with_RLLib
# Conflicts:
#	src/primaite/config/_package_data/training/training_config_main.yaml
#	src/primaite/environment/primaite_env.py
#	src/primaite/main.py
#	src/primaite/transactions/transaction.py
#	src/primaite/transactions/transactions_to_file.py
2023-07-03 19:51:52 +01:00
Chris McCarthy
7816e94f83 #917 - Synced with dev (at the point of random red agent) 2023-07-03 17:25:21 +01:00
Chris McCarthy
dffa612ec8 Merge remote-tracking branch 'origin/feature/917_Integrate_with_RLLib' into feature/917_Integrate_with_RLLib 2023-07-03 17:12:03 +01:00
Marek Wolan
4b5cf12aa3 Merged PR 103: Change build pipeline to enable installing from wheel on windows
## Summary
Just splits the install primaite step into two, depending on whether the agent is using Windows or not.

## Test process
Ran a build successfully.

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style
2023-07-03 16:10:09 +00:00
Marek Wolan
7ddedfcc57 Updated azure-ci-build-pipeline.yaml 2023-07-03 16:02:59 +00:00
Czar Echavez
a883e45bbf Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test 2023-07-03 16:56:44 +01:00
Marek Wolan
8ab936fcdc Merged PR 100: Flatten observation spaces and improve transactions for observations
## Summary
*Replace this text with an explanation of what the changes are and how you implemented them. Can this impact any other parts of the codebase that we should keep in mind?*

## Test process
I ran some training sessions to ensure that the outputted transaction list has the correct data and headers. I was also able to verify that the agent is able to train with observation spaces containing multiple components.

I trained an agent on laydown 3 with NODE_LINK_TABLE both as normal and flattened spaces and the agent learned in both instances.
![image.png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/100/attachments/image.png)  ![image (2).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/100/attachments/image%20%282%29.png)
## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1558
2023-07-03 15:54:00 +00:00
Chris McCarthy
d2764d53cc Merge remote-tracking branch 'origin/dev' into feature/917_Integrate_with_RLLib
# Conflicts:
#	src/primaite/config/_package_data/training/training_config_main.yaml
#	src/primaite/environment/primaite_env.py
2023-07-03 15:07:09 +01:00
Marek Wolan
12c18adeb1 Merge remote-tracking branch 'origin/dev' into feature/1558-flatten-spaces 2023-07-03 15:03:10 +01:00
Marek Wolan
178bd4dc7f Merge branch 'dev' into feature/1558-flatten-spaces 2023-07-03 15:01:56 +01:00
Marek Wolan
f47dd8bf61 Updated azure-ci-build-pipeline.yaml 2023-07-03 13:36:33 +00:00
Czar Echavez
dc4c2c8854 Merged PR 102: 1522 Red Agent random behaviour
## Summary
Ported over ADSP changes regarding the randomised red agent.
Red agent currently only works on laydown configs which contain links.

Each episode generates random red agent instructions
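
A minimal sketch of the per-episode idea, not the PrimAITE implementation. `RedInstruction`, `make_red_instructions`, the node names, and the optional `seed` argument are all hypothetical and only illustrate drawing a fresh set of instructions for each episode:

```python
import random
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class RedInstruction:
    """Hypothetical stand-in for a single red agent instruction."""

    step: int
    target_node: str


def make_red_instructions(
    nodes: Sequence[str],
    episode_steps: int,
    count: int,
    seed: Optional[int] = None,
) -> List[RedInstruction]:
    """Draw `count` random instructions for one episode."""
    # A local generator keeps the draw independent of global random state.
    rng = random.Random(seed)
    return [
        RedInstruction(step=rng.randrange(episode_steps), target_node=rng.choice(nodes))
        for _ in range(count)
    ]
```

Calling the helper once per episode without a seed gives a different set of instructions each time; passing a seed makes the draw repeatable.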

## Test process
Written a test that ensures that the random red agent produces random red agent instructions

| Random red agent | Laydown                | Agent Identifier | Run 1                                                                              | Run 2                                                                              | Run 3                                                                              |
|------------------|------------------------|------------------|------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|
| NONE             | Very Basic (Laydown 3) | A2C              | ![image (4).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%284%29.png)  | ![image (8).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%288%29.png)  | ![image (9).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%289%29.png)  |
| RANDOM           | Very Basic (Laydown 3) | A2C              | ![image (5).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%285%29.png)  | ![image (6).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%286%29.png)  | ![image (7).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%287%29.png)  |
| NONE             | Very Basic (Laydown 3) | PPO              | ![image (10).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%2810%29.png) | ![image (11).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%2811%29.png) | ![image (12).png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/102/attachments/image%20%2812%29.png) |
| RANDOM           | Very Basic (Laydown 3) | PPO              ...
2023-07-03 13:09:32 +00:00
Marek Wolan
8101f49a21 Updated azure-ci-build-pipeline.yaml 2023-07-03 12:44:01 +00:00
Marek Wolan
63a4c1119b Updated azure-ci-build-pipeline.yaml 2023-07-03 12:40:02 +00:00
Marek Wolan
94ca28a85f Add windows build option 2023-07-03 12:37:08 +00:00
Czar Echavez
cb9d40579f #1522: create_random_red_agent -> _create_random_red_agent + converting NodeStateInstructionRed into a dataclass 2023-07-03 13:36:14 +01:00
Czar Echavez
0943e9511b #1522: refactor red_agent_identifier -> random_red_agent so that it is a boolean + documentation 2023-07-03 12:18:58 +01:00
Chris McCarthy
c3ec33e4df #917 - Added Windows and MacOS to build pipeline. Updated so that runs only Python 3.8 and 3.10 (middle version not required) 2023-07-03 12:03:36 +01:00
Chris McCarthy
123ec8343c Merge remote-tracking branch 'origin/dev' into feature/917_Integrate_with_RLLib
# Conflicts:
#	tests/test_reward.py
2023-07-03 11:59:16 +01:00
Christopher McCarthy
c38c13b829 Apply suggestions from code review 2023-07-03 10:47:26 +00:00
Czar Echavez
6c4a538b41 #1522: run pre-commit 2023-07-03 10:08:25 +01:00
Czar Echavez
ae56827bae Merge branch 'dev' into feature/1522-Random-Red-Agent-Behaviour 2023-07-03 09:59:25 +01:00
Czar Echavez
4299170ce4 #1522: added a check for existing links in laydown + test that checks if red agent instructions are random 2023-07-03 09:46:52 +01:00
Sunil Samra
4f0f542570 Merged PR 93: 1555 - updated doc-string to make test understanding easier
## Summary
Changed doc-string of test_reward.py to reflect the new test and what it is trying to do rather than the old outdated one.

## Test process
NA - no logic changes

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

1555 - updated doc-string to make test understanding easier

Related work items: #1555, #1556
2023-07-03 08:10:17 +00:00
Marek Wolan
ee94993344 Apply suggestions from code review 2023-07-03 08:00:51 +00:00
SunilSamra
41aed12f27 901 - merged with changes made to dev 2023-07-03 08:17:52 +01:00
SunilSamra
ccad245e6f Merge remote-tracking branch 'origin/dev' into feature/1555-update-test-reward-doc-string 2023-07-03 08:10:28 +01:00
Chris McCarthy
16534237e0 #917 - Dropped VerboseLevel in enums.py and changed OutputVerboseLevel to SB3OutputVerboseLevel 2023-06-30 17:09:50 +01:00
Chris McCarthy
27ca53878a #917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
2023-06-30 16:52:57 +01:00
Marek Wolan
605ff98a24 Fix flattening when there are no components. 2023-06-30 15:43:15 +01:00
Marek Wolan
975ebd6de2 revert unnecessary changes. 2023-06-30 13:16:30 +01:00
Chris McCarthy
203cc98494 #917 - Fixed primaite_config.yaml issue in cli.py
- Added kaleido to deps in pyproject.toml
2023-06-30 11:40:26 +01:00
Marek Wolan
32d5889b11 Update docs 2023-06-30 10:44:04 +01:00
Marek Wolan
2a8d28cba6 Remove redundant cols from transactions 2023-06-30 10:41:56 +01:00
Czar Echavez
3e691b4f46 #1522: remove numpy randomisation + added random red agent config 2023-06-30 10:37:23 +01:00
Chris McCarthy
d5402cdce8 #917 - Added tensorflow to main deps for RLlib.
- Dropped support for Python 3.11 as it is not supported by Ray RLlib.
- Made release pipeline only run once as we're now no longer using pure path wheels.
2023-06-30 10:24:59 +01:00
Marek Wolan
c3c4512544 Remove temporary file 2023-06-30 09:54:34 +01:00
Chris McCarthy
73015802ec #917 - Integrated the PrimaiteSession into all tests.
- Ran a full pre-commit hook and thus encountered tons of fixes required
2023-06-30 09:08:13 +01:00
Marek Wolan
c77fde3dd3 Fix observation representation in transactions 2023-06-29 15:26:07 +01:00
Czar Echavez
f61d50a96f #1522: fixing create random red agent function 2023-06-29 15:03:11 +01:00
Czar Echavez
a2e02c3cfd Merge branch 'dev' into feature/1522-Random-Red-Agent-Behaviour 2023-06-29 14:17:41 +01:00
Chris McCarthy
7f912df383 #917 - Began the process of reloading existing agents into the session 2023-06-28 19:54:00 +01:00
Chris McCarthy
1d3778f400 #917 - Overhauled transaction and mean reward writing.
- Separated out learning outputs from evaluation outputs
2023-06-28 16:34:00 +01:00
Chris McCarthy
7482192046 #917 - Synced with dev and added better logging 2023-06-28 12:01:01 +01:00
Marek Wolan
9666b92caa Attempt to add flat spaces 2023-06-28 11:07:45 +01:00
Chris McCarthy
498e6a7ac1 Merge remote-tracking branch 'origin/dev' into feature/917_Integrate_with_RLLib
# Conflicts:
#	src/primaite/config/training_config.py
#	src/primaite/main.py
2023-06-28 10:11:03 +01:00
Marek Wolan
02f982afa8 Merged PR 95: Apply precommits and add precommit to build pipeline
## Summary
The code changes are purely cosmetic- the result of applying pre-commit to all our files. I also added a pre-commit step to the build pipeline to reject non-conforming PRs

## Test process
I saw that the build pipeline passes with this new step.

## Checklist
- [ ] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #1557
2023-06-28 08:14:49 +00:00
SunilSamra
b8a4ede83f 1555 - added specific steps to doc string 2023-06-27 16:59:43 +01:00
SunilSamra
8a1c0b2db7 Merge remote-tracking branch 'origin/dev' into feature/1555-update-test-reward-doc-string 2023-06-27 16:55:00 +01:00
Marek Wolan
cfeb1c6530 Merged PR 94: Fix ier reward calculation
## Summary
Logic error with negation of booleans.

## Test process
Run with debug logging to verify that no longer getting warnings about reference IERS being blocked.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Fix ier reward calculation

Related work items: #1554
2023-06-27 15:10:19 +00:00
Chris McCarthy
746f878747 Merge remote-tracking branch 'origin/bugfix/1554-fix-not-learning-iers' into feature/917_Integrate_with_RLLib 2023-06-27 15:56:56 +01:00
Marek Wolan
a8c27ec975 Merge branch 'dev' into feature/build-pipeline-precommit 2023-06-27 15:49:49 +01:00
Marek Wolan
cffdcdc0d2 Fix ier reward calculation 2023-06-27 15:27:56 +01:00
Czar Echavez
8f2fd77634 Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test 2023-06-27 14:16:10 +01:00
SunilSamra
301e8b6983 Merge remote-tracking branch 'origin/dev' into feature/1555-update-test-reward-doc-string 2023-06-27 14:09:36 +01:00
Marek Wolan
cf2f9788ec Add pre-commit 2023-06-27 13:07:54 +00:00
Marek Wolan
3adb02118c Merged PR 92: Fix reference IERs
## Summary
As per the ticket and James's explanation, there are now separate reference IERs which are used for the reference environment.

## Test process
I verified that the training can occur.
![image.png](https://dev.azure.com/ma-dev-uk/b50a61ee-86c4-48bc-9a0b-a67645ba12ee/_apis/git/repositories/2825053e-bd3b-45b2-8680-1281809eefa2/pullRequests/92/attachments/image.png)

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [n/a] I have written **tests** for any new functionality added with this PR
- [n/a] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Fix reference IERs

Related work items: #1554
2023-06-27 13:07:39 +00:00
Marek Wolan
185dbb7f02 Add pre-commits to build pipeline 2023-06-27 12:07:33 +00:00
Marek Wolan
be7d0e1745 Cosmetic changes to satisfy pre-commit 2023-06-27 13:06:10 +01:00
Marek Wolan
0bff2d2f36 Improve readability 2023-06-27 12:56:15 +01:00
Marek Wolan
79ecb8e0b9 More descriptive debug msg 2023-06-27 12:44:42 +01:00
SunilSamra
09412cb43d 1555 - updated doc-string to make test understanding easier 2023-06-27 12:27:57 +01:00
Marek Wolan
ebc0a28460 rename to prevent confusion 2023-06-27 10:45:45 +00:00
SunilSamra
ef4d2c6cdd 901 - fixed test_single_action_space.py to reflect new acl structure and added new acl_implicit_rule class attribute 2023-06-27 11:43:33 +01:00
Marek Wolan
e2d6abf833 apply pre-commits 2023-06-27 11:20:18 +01:00
Marek Wolan
feead2cd44 Fix reference IERs 2023-06-27 11:10:21 +01:00
Czar Echavez
fb50b8becf #1386: Apply suggestions from code review - make seed an optional variable 2023-06-23 07:57:31 +00:00
Brian Kanyora
e0f3d61f65 feature\1522:
Create random red agent behaviour.
2023-06-22 15:34:13 +01:00
Chris McCarthy
7f1c4ce036 #917 - Updated main config 2023-06-22 14:10:38 +01:00
Chris McCarthy
5a6fdf58d4 #917 - Got things working'ish 2023-06-20 22:29:46 +01:00
Chris McCarthy
a2cc4233b5 #917 - Finished integrating all agents to either train (policy agents) or evaluate (hard-coded agents). Still some fixing up, tidying up, loading etc. to do, and also docs. But this is all now working. 2023-06-20 16:06:55 +01:00
SunilSamra
df42a791c9 901 - changed ACL instantiation and changed acl to a private _acl (list not dict) attribute, added laydown_ACL.yaml for testing, fixed encoding of acl rules to integers for obs space, added ACL position to node action space and added a generic test where the agent adds two ACL rules. 2023-06-20 11:47:20 +01:00
Czar Echavez
1a5bd3af48 #1386: fix README.md 2023-06-20 11:22:29 +01:00
Czar Echavez
db67a829d5 #1386: added documentation + dealing with pre-commit checks 2023-06-20 11:19:05 +01:00
Czar Echavez
0ab4520904 #1386: added the ability to set deterministic and seeding RNG when training and evaluating + the fix provided in #1535 2023-06-20 10:41:30 +01:00
Chris McCarthy
03ae4884e0 #917 - Almost there. All output files being written for SB3/RLLIB PPO & A2C. Just need to bring in the hardcoded agents then update the tests and docs. 2023-06-19 21:53:25 +01:00
Chris McCarthy
23bafde457 #917 - Integrated both SB3 and RLlib agents into PrimaiteSession 2023-06-19 20:27:08 +01:00
Chris McCarthy
c2c396052f #917 - Got RLlib fully training in PrimAITE. Started integrating the other agents into the Session class 2023-06-18 22:40:56 +01:00
Chris McCarthy
6849939265 #917 - started working on the Agent abstract classes and sub-classes 2023-06-15 09:48:44 +01:00
SunilSamra
c6a947fbaf 901 - started testing for observation space 2023-06-13 16:23:32 +01:00
SunilSamra
5b59642695 901 - added max_acl_rules, implicit_acl_rule and apply_implicit rule to main_config, changed observations.py for ACLs to match the action space for ACLs, added position of acl rule to ACL action type 2023-06-13 14:51:55 +01:00
SunilSamra
fe102dff6f 901 - fixed test_acl.py tests 2023-06-13 10:01:55 +01:00
SunilSamra
cf64990cff 901 - added changes back to ticket 2023-06-13 09:45:45 +01:00
Chris McCarthy
eb3368edd6 temp commit 2023-06-13 09:42:54 +01:00
SunilSamra
cdd7183d85 901 - merged dev into my branch 2023-06-13 08:54:33 +01:00
Christopher McCarthy
9b0e24c27b Merged PR 81: #915 Packaging & Deployment
## Summary
- Created app dirs and set as constants in the top-level init.
- Renamed _config_values_main to training_config.py and renamed the ConfigValuesMain class to TrainingConfig.
- Moved training_config.py to src/primaite/config/training_config.py
- Renamed all training config yaml file keys to make creating an instance of TrainingConfig easier.
- Moved action_type and num_steps over to the training config.
- Decoupled the training config and lay down config.
- Refactored main.py so that it can be run from the CLI and can take a training config path and a lay down config path.
- Refactored all outputs so that they save to the session dir.
- Added some necessary setup scripts that handle creating app dirs, fronting example config files to the user, fronting demo notebooks to the user, performing clean-up in between installations etc.
- Added functions that attempt to retrieve the file path of the user's example config files that have been fronted by the primaite setup.
- Added logging config and a getLogger function in the top-level init.
- Refactored all log entries to use a logger using the primaite logging config.
- Added basic typer CLI for doing things like setup, viewing logs, viewing primaite version, running a basic session.
- Updated test to use new features and config structures.
- Made tests log to temp directory
- typer==0.9.0 added to pyproject.toml
- Refactored documentation and included API docs, dependencies.
- Make files now re-build autosummary and deps file.
- Added typer and platformdirs to deps in pyproject.toml.
- Made root_is_pure = True in setup.py as platform/python specific wheels don't need to be built but the option is there should we need to.

## Test process
- Added an e2e test for primaite.main.run func.
- Added legacy config file conversion tests
- added

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #915
2023-06-12 18:17:46 +00:00
Chris McCarthy
785409e12a Synced with dev 2023-06-12 16:59:31 +01:00
Chris McCarthy
a08ec8844a Merge remote-tracking branch 'origin/dev' into feature/915_PRI-31_Packaging_Deployment
# Conflicts:
#	docs/source/about.rst
#	src/primaite/main.py
#	src/primaite/nodes/node.py
2023-06-12 16:42:26 +01:00
Brian Kanyora
eac17b6e16 Merged PR 75: Fixing the functionality of resetting a node
## Summary:
Split the ticket into two tasks.

Task 1: Fixed the resetting operating state to set compromised or overwhelmed services or operating system back to a good state. Added a reset count that switches the node into a good state.

Task 2: Created a "SHUTTING DOWN" operating state that lasts for a configurable duration and a "BOOTING" operating state that lasts for a configurable duration.
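
As a rough sketch of the transitional states described in Task 2 (not the actual node class): `SketchNode`, the `OperatingState` enum, and the `shutdown_duration` / `boot_duration` parameters are hypothetical names, and the durations are assumed to be counted in environment steps.

```python
from enum import Enum


class OperatingState(Enum):
    ON = "ON"
    OFF = "OFF"
    SHUTTING_DOWN = "SHUTTING DOWN"
    BOOTING = "BOOTING"


class SketchNode:
    """Hypothetical node that spends a set number of steps shutting down or booting."""

    def __init__(self, shutdown_duration: int = 2, boot_duration: int = 3) -> None:
        self.state = OperatingState.ON
        self.shutdown_duration = shutdown_duration
        self.boot_duration = boot_duration
        self._steps_left = 0

    def turn_off(self) -> None:
        # Only a running node can start shutting down.
        if self.state is OperatingState.ON:
            self.state = OperatingState.SHUTTING_DOWN
            self._steps_left = self.shutdown_duration

    def turn_on(self) -> None:
        # Only a powered-off node can start booting.
        if self.state is OperatingState.OFF:
            self.state = OperatingState.BOOTING
            self._steps_left = self.boot_duration

    def step(self) -> None:
        """Advance one step; finish the transition once the countdown expires."""
        if self.state not in (OperatingState.SHUTTING_DOWN, OperatingState.BOOTING):
            return
        self._steps_left -= 1
        if self._steps_left <= 0:
            self.state = (
                OperatingState.OFF
                if self.state is OperatingState.SHUTTING_DOWN
                else OperatingState.ON
            )
```

With the default values, a node that is told to turn off stays in SHUTTING DOWN for two steps before reaching OFF.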

## Test process
The first test checks that the reset changes the node to a good state when it's set to a COMPROMISED state. The last two tests make sure that the node boots and shuts down correctly.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

Related work items: #898, #1438
2023-06-12 15:21:47 +00:00
Chris McCarthy
8f86bda4d2 Merge remote-tracking branch 'origin/dev' into feature/898-Fix-the-functionality-of-resetting-a-node 2023-06-12 14:20:16 +01:00
Chris McCarthy
3c8a8188fb #951 - Can not view and change the log level from the cli.
- Fixed write transaction issue in transactions_to_file.py
2023-06-09 22:23:45 +01:00
Christopher McCarthy
29d1566789 Apply suggestions from code review 2023-06-09 20:31:12 +00:00
SunilSamra
c5175c500e 901 - added logic to add acls to list (needs more logic adding to it) 2023-06-09 16:56:42 +01:00
Chris McCarthy
605737cd5f #915 - Annotated logs func in cli.py to take -n.
- Fixed entry point on main.py
- Commented out the print reward line in step func of primaite_env.py.
- Added jupyterlab==3.6.1 to pyproject.toml
2023-06-09 16:44:49 +01:00
Chris McCarthy
747ea9d0c6 #915 - Force app dir creation before config file sink 2023-06-09 16:04:56 +01:00
Chris McCarthy
f5e195604f #915 - Synced with dev 2023-06-09 15:49:48 +01:00
SunilSamra
29ba64462a 901 - changed name of enum in enums.py and added class attributes in access_control_list.py 2023-06-09 15:45:13 +01:00
SunilSamra
afc133cbc5 901 - added ACL list to observations.py as its own observation space with the ACL attributes and the position of the ACL rule in the ACL list, added ImplicitFirewallRule to enums.py and added acl_implicit_rule, max_acl_list to primaite_env.py 2023-06-09 15:17:20 +01:00
Chris McCarthy
0dbd89e5cb Merge remote-tracking branch 'origin/dev' into feature/915_PRI-31_Packaging_Deployment
# Conflicts:
#	docs/source/about.rst
#	docs/source/config.rst
#	src/primaite/common/config_values_main.py
#	src/primaite/environment/primaite_env.py
#	src/primaite/main.py
#	tests/config/multidiscrete_obs_space_laydown_config.yaml
#	tests/config/obs_tests/laydown.yaml
#	tests/conftest.py
#	tests/test_observation_space.py
2023-06-09 13:41:05 +01:00
Chris McCarthy
af4e71db9b #915 - Synced with dev to bring in changes from #898 2023-06-09 13:11:14 +01:00
SunilSamra
7382ed26b3 901 - changed AccessControlList in access_control_list.py from a dict to a list 2023-06-09 11:25:45 +01:00
Marek Wolan
fd3b304373 Merged PR 69: Configurable observation space.
## Summary
This PR implements a new module called `observations` within `primaite.environment`.

The module is able to keep track of the observation space and to generate observations for the blue agent. It builds the observation space from components. Each component can be configured by supplying parameters at instantiation. For example, the Link Traffic Levels component lets the user customise how many levels there should be.

Note: If a space contains multiple components, they are combined into a 'gym.spaces.Tuple' Space. This is not compatible with some learning agents so we may need to add the options to flatten the observation space.
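
To illustrate what flattening means here, a minimal sketch that uses plain Python tuples as stand-ins for per-component observations. `flatten_observation` is a hypothetical helper, not part of the `observations` module, and it does not use gym's own utilities:

```python
from typing import Any, List


def flatten_observation(obs: Any) -> List[int]:
    """Recursively flatten a nested tuple/list observation into one flat list."""
    flat: List[int] = []
    for item in obs:
        if isinstance(item, (list, tuple)):
            # A nested component: flatten it and append its values in order.
            flat.extend(flatten_observation(item))
        else:
            flat.append(item)
    return flat
```

For example, a two-component observation `((1, 2), (3, (4, 5)))` would become a single list of five values, which is the shape agents that cannot handle tuple spaces expect.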

## Test process
I was able to run the main script with a single-component obs space. I also wrote several unit and integration tests for the new functionality.

## Checklist
- [x] This PR is linked to a **work item**
- [x] I have performed **self-review** of the code
- [x] I have written **tests** for any new functionality added with this PR
- [x] I have updated the **documentation** if this PR changes or adds functionality
- [x] I have run **pre-commit** checks for code style

If you review this, please check the linked tickets and make sure you agree that this PR addresses them fully.

Related work items: #886, #924, #1468, #1469
2023-06-09 09:52:47 +00:00
Chris McCarthy
9b4ed1199b Merge remote-tracking branch 'origin/dev' into feature/915_PRI-31_Packaging_Deployment
# Conflicts:
#	tests/conftest.py
#	tests/test_observation_space.py
#	tests/test_reward.py
2023-06-09 10:35:14 +01:00
Christopher McCarthy
647ba2fcc1 Apply suggestions from code review 2023-06-09 09:31:01 +00:00
Marek Wolan
64bf4bf58a Fix obs tests with new changes 2023-06-09 10:28:24 +01:00
Marek Wolan
b917b65d49 Merge remote-tracking branch 'origin/dev' into feature/1468-observations-class 2023-06-09 09:01:54 +01:00
Sunil Samra
6d502045cb Merged PR 76: 893 - Combine NODE and ACL action spaces into single action space
## Summary
To do this, I have altered `primaite_env` to add the changes from ADSP branch for implementing the `ANY` action space.

It impacts `NODE` and `ACL` action spaces in `primaite_env.py` as all three of them are now discrete action spaces, using dictionary keys to represent different valid actions a node can take on each step.

Previously they were multi-discrete where a single action would look like this `[1,2,1,0]`.

Now an action looks like this, a dictionary entry `{.. 5: [1,2,1,0] ... }` whereby the new action is `5` for example.

It changes the `enums.py` where I added the `ANY` into `ActionType`.

I have also added a package from the ADSP branch agents to add the file utils.py. The file contains functions used by primaite_env.py to decide and check valid actions a node can take and removes the ones which are unnecessary and invalid. This is done for all three types, `NODE`, `ACL` and `ANY`.
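
The mapping from a single discrete action to a multi-discrete action can be sketched roughly as follows. `build_action_dict` and the `is_valid` predicate are hypothetical and only illustrate the idea of enumerating actions and dropping invalid ones; they are not the functions in utils.py:

```python
from itertools import product
from typing import Callable, Dict, List, Sequence


def build_action_dict(
    dims: Sequence[int],
    is_valid: Callable[[List[int]], bool] = lambda action: True,
) -> Dict[int, List[int]]:
    """Give every valid multi-discrete action an integer key.

    `dims` holds the number of choices per action component. The agent then
    picks one integer key, and the environment looks up the full action.
    """
    candidates = (list(a) for a in product(*(range(n) for n in dims)))
    return dict(enumerate(a for a in candidates if is_valid(a)))
```

The environment's action space is then a discrete space of size `len(action_dict)`, and a chosen key such as `5` is translated back into its list form before being applied.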

## Test process
I have written a unit test in `test_single_action_space.py` which checks the new action space for an `ANY` laydown config has both types of actions in the `action_space` dictionary stored by the environment.

I have written an integration test to check an agent is carrying out both `NODE` and `ACL` actions in a single episode, where I have hard coded the agent to do two specific things on two different steps.
On one step, I tell the `computer_1` node to turn off one of the nodes and on the other step it creates an ACL rule denying communication between `computer_1` and `switch_1` nodes.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #893, #1429
2023-06-09 07:28:31 +00:00
Chris McCarthy
de86c85b23 #915 - Refactored documentation and included API docs, dependencies.
- make files now re-build autosummary and deps file.
- Added typer and platformdirs to deps in pyproject.toml.
- Made root_is_pure = True in setup.py as platform/python specific wheels don't need to be built but the option is there should we need to.
- Added an e2e test for primaite.main.run func.
2023-06-08 15:57:38 +01:00
Chris McCarthy
1809cbe1f4 #915 - typer==0.9.0 added to pyproject.toml 2023-06-08 08:56:39 +01:00
Chris McCarthy
61bd70a6c9 #915 - Ensured LOG_DIR is created so primaite package can be used to perform setup while still logging using primaite logs. 2023-06-08 08:49:06 +01:00
Chris McCarthy
0795a7b4f8 #915 - Ensured primaite setup is carried out on devops pipelines that install primaite. 2023-06-08 08:39:00 +01:00
Chris McCarthy
02e37e5096 #915 - Fixed issue in conftest.py where session_path and timestamp_str were not being passed to Primaite.
- Also now logging all test outputs to temp directory.
2023-06-07 22:57:37 +01:00
Chris McCarthy
273876873e #915 - Created app dirs and set as constants in the top-level init.
- Renamed _config_values_main to training_config.py and renamed the ConfigValuesMain class to TrainingConfig.
- Moved training_config.py to src/primaite/config/training_config.py
- Renamed all training config yaml file keys to make creating an instance of TrainingConfig easier.
- Moved action_type and num_steps over to the training config.
- Decoupled the training config and lay down config.
- Refactored main.py so that it can be run from the CLI and can take a training config path and a lay down config path.
- Refactored all outputs so that they save to the session dir.
- Added some necessary setup scripts that handle creating app dirs, fronting example config files to the user, fronting demo notebooks to the user, performing clean-up in between installations etc.
- Added functions that attempt to retrieve the file path of the user's example config files that have been fronted by the primaite setup.
- Added logging config and a getLogger function in the top-level init.
- Refactored all log entries to use a logger using the primaite logging config.
- Added basic typer CLI for doing things like setup, viewing logs, viewing primaite version, running a basic session.
- Updated test to use new features and config structures.
- Began updating docs. More to do here.
2023-06-07 22:40:16 +01:00
Marek Wolan
9417cd85ab Apply suggestions from code review. 2023-06-07 15:25:11 +01:00
SunilSamra
6f3e40e390 893 - removed unnecessary functions from utils.py and changed single_action_space_fixed_blue_actions_main_config.yaml back to GENERIC agentIdentifier after PR comments 2023-06-07 14:39:52 +01:00
Marek Wolan
89cea9289b Merge remote-tracking branch 'origin/dev' into feature/1468-observations-class 2023-06-07 14:39:20 +01:00
Brian Kanyora
038abb9be7 feature\898:
Added doc strings
2023-06-07 11:09:00 +01:00
SunilSamra
709fbc500e 893 - removed print statements for demonstration 2023-06-07 09:19:30 +01:00
SunilSamra
57b982eea3 Merge remote-tracking branch 'origin/dev' into feature/893-node-acl-into-one-action-space 2023-06-07 09:18:24 +01:00
Marek Wolan
ef3cef530b Merged PR 61: Fix minor logic errors in main script
This PR fixes some minor issues that I found in the main.py script. Namely:

1. The first observation was always all zeroes when using a generic agent. This is because the `update_environment_obs()` method is not called automatically and is only called by `env.reset()`.
2. The config yaml is never closed as the close function of the file reader was only referenced but never called.
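
The second issue is a common Python slip: referencing `close` without parentheses looks up the method but never calls it. A minimal sketch of the bug and one possible fix, using a JSON file as a stdlib stand-in for the YAML config (both function names are hypothetical, and the file handle is returned only so its state can be inspected):

```python
import json


def load_config_leaky(path):
    """Reproduces the bug: `f.close` is referenced but never called."""
    f = open(path)
    data = json.load(f)
    f.close  # no parentheses, so this is only an attribute lookup
    return data, f


def load_config(path):
    """One way to fix it: a context manager closes the file on exit."""
    with open(path) as f:
        data = json.load(f)
    return data, f
```

After `load_config_leaky` returns, the handle's `closed` attribute is still `False`; after `load_config` it is `True`.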

Related work items: #1441
2023-06-06 15:02:40 +00:00
SunilSamra
6cc9516744 893 - added new line for assert statements 2023-06-06 15:54:35 +01:00
Marek Wolan
bfd19280d5 Merge remote-tracking branch 'origin/dev' into bugfix/1441-main-py-minor-bugs 2023-06-06 15:50:35 +01:00
SunilSamra
2eff3912fb 893 - added consistent action for test_reward.py 2023-06-06 13:49:22 +01:00
SunilSamra
69c5c9458b 893 - 2023-06-06 13:47:07 +01:00
SunilSamra
af44b99b6f Merge remote-tracking branch 'origin/dev' into feature/893-node-acl-into-one-action-space 2023-06-06 13:46:01 +01:00
Marek Wolan
c969bc32f5 Merge remote-tracking branch 'origin/dev' into feature/1468-observations-class 2023-06-06 13:40:17 +01:00
Marek Wolan
a987ffb745 Merged PR 62: Make reward calculation consider red POL
Check out the linked bug ticket to understand the issue.

The fix was very simple: just changing which variable is passed to the reward calculation function.

Related work items: #1442
2023-06-06 12:27:55 +00:00
SunilSamra
babd4eb5f8 893 - changed action in conftest.py back to sample of the environment action space 2023-06-06 13:23:08 +01:00
SunilSamra
d922d4d054 893 - returned config_values in conftest to move run_generic_set_actions into test_single_action_space.py 2023-06-06 13:21:04 +01:00
SunilSamra
940013f9a6 Merge remote-tracking branch 'origin/feature/893-node-acl-into-one-action-space' into feature/893-node-acl-into-one-action-space 2023-06-06 13:12:58 +01:00
SunilSamra
e15c8c8c89 893 - applied changes raised during PR 2023-06-06 13:12:28 +01:00
Sunil Samra
dcab4b0d4a Apply suggestions from code review 2023-06-06 12:07:22 +00:00
SunilSamra
8558ca1020 893 - updated the docs to reflect changes made to action space 2023-06-06 11:57:04 +01:00
SunilSamra
17d036302f Merge remote-tracking branch 'origin/dev' into feature/893-node-acl-into-one-action-space 2023-06-06 11:56:52 +01:00
SunilSamra
49707b0a17 893 - set the action_space to NOTHING so test_reward.py passes and removed unnecessary test print statements 2023-06-06 11:10:38 +01:00
Brian Kanyora
e52dfababc feature\898:
Fixed the resetting operating state so that it sets compromised or overwhelmed services, or the operating system, back to a good state. Added a reset count that switches the node into a good state.
Created a "SHUTTING DOWN" operating state and a "BOOTING" operating state, each lasting for a configurable duration.
Created a test file to check that a reset returns the node to a good state when it is set to a COMPROMISED state. The last two tests make sure that the node boots and shuts down correctly.
Lastly, updated the docs file as well.
2023-06-06 11:03:43 +01:00
SunilSamra
55f13ae654 893 - added new tests to test action space size and node is completing both sets of actions in a single episode and created new main config 2023-06-06 11:00:41 +01:00
Brian Kanyora
8b61fbebe4 feature\898:
Fixed the resetting operating state so that it sets compromised or overwhelmed services, or the operating system, back to a good state. Added a reset count that switches the node into a good state.
Created a "SHUTTING DOWN" operating state and a "BOOTING" operating state, each lasting for a configurable duration.
Created a test file to check that a reset returns the node to a good state when it is set to a COMPROMISED state. The last two tests make sure that the node boots and shuts down correctly.
Lastly, updated the docs file as well.
2023-06-05 23:59:32 +01:00
Brian Kanyora
a48b217cf3 feature\898: 2023-06-02 16:13:16 +01:00
Brian Kanyora
051cd7da2b Merge branch 'dev' into feature/898-Fix-the-functionality-of-resetting-a-node 2023-06-02 14:56:31 +01:00
Brian Kanyora
e5b60c2f95 feature\898: 2023-06-02 14:54:23 +01:00
Marek Wolan
cdd710d672 Merge remote-tracking branch 'origin/dev' into bugfix/1442-reward-ignores-red-pol 2023-06-02 14:22:45 +01:00
Marek Wolan
1ee6a37188 Merge remote-tracking branch 'origin/dev' into feature/1468-observations-class 2023-06-02 13:41:49 +01:00
Marek Wolan
9d868c5090 Update docs with configurable obs space info 2023-06-02 13:23:03 +01:00
Marek Wolan
25ec0d93a9 Fix Link Traffic Levels observation encoding 2023-06-02 13:15:38 +01:00
Marek Wolan
2330a30021 Get observation tests passing 2023-06-02 13:08:11 +01:00
Marek Wolan
f37b943f7e Add tests for observations 2023-06-02 12:59:01 +01:00
SunilSamra
2c95087056 893 - added test which shows the new action space has been created when ANY is selected in single_action_space_lay_down_config.yaml 2023-06-02 11:55:31 +01:00
SunilSamra
d854773e84 893 - added ANY to enums.py 2023-06-02 09:51:15 +01:00
Marek Wolan
b6ce1cbae9 Edit configs for observation space 2023-06-02 09:10:53 +01:00
Marek Wolan
875562c385 begin updating observations tests 2023-06-01 21:56:05 +01:00
Marek Wolan
85c102cfc1 Update docs page on observations 2023-06-01 21:42:34 +01:00
Marek Wolan
484a31d082 Add docstrings to new observation code 2023-06-01 21:28:38 +01:00
Marek Wolan
c0b214612a Let single-component spaces not use Tuple Spaces 2023-06-01 18:01:47 +01:00
Marek Wolan
3e208bad9b Better Obs default handling 2023-06-01 17:50:18 +01:00
Marek Wolan
7041b79d2a Fix trying to init obs before building network 2023-06-01 17:42:35 +01:00
Marek Wolan
2b25573378 Integrate obs handler with Primaite Env 2023-06-01 16:42:10 +01:00
SunilSamra
8efa0295df 1443 - added in print test statements 2023-06-01 16:27:25 +01:00
Marek Wolan
46352ff9c2 Integrate observation handler with components 2023-06-01 13:28:40 +01:00
Marek Wolan
c276a31b9c Merged PR 65: Add MultiDiscrete observation spaces
**Summary:**

This adds support for the MultiDiscrete observation spaces, the same as what exists in the ADSP branch. The observation space is now configurable in the same way as the action space- by selecting a config item within the laydown config yaml.
The 'box' option has the same behaviour as before.

**Test Process:**

I added two integration tests to ensure that creating the environment is possible with both types of observation space. I also checked that all existing unit tests run fine as long as I update the observation space in the yaml to box.

**Other comments:**
I also updated the documentation relating to observation spaces, please check if the explanation makes sense.

Related work items: #1463
2023-06-01 11:05:00 +00:00
Marek Wolan
c904334c83 Merge branch 'feature/1463-multidiscrete-observation-option' into feature/1468-observations-class 2023-06-01 11:09:21 +01:00
Marek Wolan
3b0d05e9c9 More info in docstring 2023-06-01 11:02:10 +01:00
Marek Wolan
37d606eda6 Separate obs functions and provide docstrings 2023-06-01 10:57:11 +01:00
Marek Wolan
bfd20b7a6b Type hint init_observations return type 2023-06-01 09:57:33 +01:00
Marek Wolan
a0960555fc Fix docstrings to use ReST format 2023-06-01 09:54:45 +01:00
Marek Wolan
76ec9683cb Improve observation space test 2023-06-01 09:45:46 +01:00
Marek Wolan
4ee77656be Merge remote-tracking branch 'origin/dev' into feature/1463-multidiscrete-observation-option 2023-06-01 09:02:48 +01:00
Marek Wolan
6e58c01e8d Start creating observations module 2023-05-31 17:03:53 +01:00
SunilSamra
81e9ddca9b 1443 - reverted changes made to observation space and added config files for testing 2023-05-31 14:11:15 +01:00
SunilSamra
a8cc50a495 Merge remote-tracking branch 'origin/dev' into feature/893-node-acl-into-one-action-space 2023-05-31 13:28:39 +01:00
SunilSamra
9a231821ea 1443 - added changes from ADSP to observation space in primaite_env.py 2023-05-31 13:15:25 +01:00
Marek Wolan
d8cd96100e Merged PR 66: Add a Pull Request template
I wanted to add this pull request template just as a checklist for everyone to ensure they add tests and update documentation.

Do you think it's necessary? Feel free to discuss in the comments of this PR or accept/reject the suggestion.

Related work items: #1467
2023-05-31 11:55:38 +00:00
Marek Wolan
31b5031808 Merge remote-tracking branch 'origin/dev' into bugfix/1441-main-py-minor-bugs 2023-05-31 11:07:06 +01:00
Marek Wolan
5906ed7e39 Merge remote-tracking branch 'origin/dev' into bugfix/1442-reward-ignores-red-pol 2023-05-31 11:04:00 +01:00
Marek Wolan
c6bb855456 Revert unnecessary main.py change 2023-05-31 09:55:28 +00:00
Marek Wolan
2260cb1668 Revert config changes by removing observations 2023-05-31 10:52:57 +01:00
Marek Wolan
65f2d6202f Add default observation type 2023-05-31 10:51:29 +01:00
Marek Wolan
733025bd53 Merge remote-tracking branch 'origin/dev' into feature/1463-multidiscrete-observation-option 2023-05-31 10:46:18 +01:00
SunilSamra
c6db98c1c2 Merge remote-tracking branch 'origin/dev' into feature/893-node-acl-into-one-action-space 2023-05-31 10:34:42 +01:00
Sunil Samra
fbb26bbc63 Merged PR 64: 1443-check-reward-function
In reward.py, the IF statements that assign config_values reward values currently compare the initial state to the reference state. However, they should compare the reference state (what the state would be without any blue/red agent interference) with the final state (the state after red and blue actions have taken effect).

Change the IF statement logic to check `reference_node_os_state` first and then, in the following IF statement, compare it against `final_node_os_state`.
Do this for all reward functions
Write tests to evaluate step rewards
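
The reference-versus-final comparison described above can be sketched as follows. This is an illustration only: the function name, the string state values, and the `compromised_should_be_good` key with its value are assumptions, not the contents of reward.py.

```python
def score_node_os_state(reference_node_os_state, final_node_os_state, config_values):
    """Return the reward contribution for one node's OS state.

    The reference state is what the node would look like with no red or
    blue interference; the final state is what it looks like after both.
    """
    if reference_node_os_state == "GOOD":
        if final_node_os_state == "COMPROMISED":
            # The node should be good but is compromised: apply the penalty.
            return config_values["compromised_should_be_good"]
    # States match, or the case is not covered by this sketch.
    return 0.0


# Hypothetical reward values for illustration.
config_values = {"compromised_should_be_good": -10.0}
penalty = score_node_os_state("GOOD", "COMPROMISED", config_values)
no_penalty = score_node_os_state("GOOD", "GOOD", config_values)
```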

Related work items: #1443
2023-05-31 09:31:01 +00:00
Marek Wolan
5ea77f3e75 Added pull_request_template.md 2023-05-31 09:26:40 +00:00
Sunil Samra
83694fe537 Apply suggestions from code review 2023-05-31 08:09:09 +00:00
Marek Wolan
045e074d0f Update docs on MultiDiscrete observation spaces. 2023-05-30 16:54:34 +01:00
Marek Wolan
6507529db3 Add test for new multidiscrete spaces 2023-05-30 15:48:11 +01:00
Marek Wolan
fa44dd1a26 Update configs and transactions to include new obs 2023-05-30 15:24:13 +01:00
Marek Wolan
0227769c34 Fix observation node shape 2023-05-30 15:16:14 +01:00
Marek Wolan
375e20a67b Configure observation type MULTIDISCRETE 2023-05-30 15:11:41 +01:00
Marek Wolan
2724838cf8 Setup testing scripts 2023-05-30 13:14:43 +01:00
SunilSamra
91dec9e83d 1443 - updated test_reward.py to reflect updates to reward.py so that the correct config values are called i.e. compromisedShouldBeGood on the correct steps during the training run 2023-05-30 11:50:54 +01:00
SunilSamra
0483eeca82 1443 - changed IF statements from if initial ... if reference to if reference ... if final to compare the final state (state after red and blue actions) with the reference state (state with no red or blue action and with green normal network traffic occurring) 2023-05-30 11:40:40 +01:00
Marek Wolan
77a6fd6aff Make reward calculation consider red POL 2023-05-30 08:50:57 +00:00
Marek Wolan
8a24427bf7 Fix minor logic errors in main script 2023-05-26 14:50:15 +01:00
SunilSamra
dc011a489c 1429 - added code from ADSP branch to primaite_env.py and added NONE = 0 to NodePOLType in enums.py 2023-05-26 14:29:02 +01:00
SunilSamra
9d3d8d5945 1429 - created new branch from dev, added enums to enums.py, created agents package and utils.py file, added option to primaite_env.py for ANY action type and changed how the action spaces are defined using the ADSP branch 2023-05-26 10:17:45 +01:00
Christopher McCarthy
b255f557db Merged PR 60: #1355 - Carried out full renaming in node.py, active_node.py, passive_node.py...
**The following changes are made to constructor params in the Node class and its children (ActiveNode, PassiveNode, and ServiceNode):**
- _id -> node_id
- _name -> name
- _type -> node_type
- _priority -> priority
- _state -> hardware_state
- _ip_address -> ip_address
- _os_state -> software_state
- _file_system_state -> file_system_state
- _config_values -> config_values
- Add type hints to all params.

(node_id, name, and ip_address are str, states and other defined types are the respective enums, leave config_values without a type for now.)

**The following changes are made to instance variables in the Node class and its children:**
- self.type -> self.node_type
- self.operating_state -> self.hardware_state
- self.os_state -> self.software_state
- Add type hints to all instance variables.

(node_id, name, and ip_address are str, states and other defined types are the respective enums, leave config_values without a type for now.)

**The following changes are made to the config files where itemType is NODE:**
- itemType -> item_type
- id -> node_id
- portsList -> ports_list
- serviceList -> service_list
- baseType -> base_type
- nodeType -> node_type
- hardwareState -> hardware_state
- ipAddress -> ip_address
- softwareState -> software_state
- fileSystemState -> file_system_state
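
As a rough illustration of the config key renaming listed above, the mapping can be applied to a legacy NODE item like this. The mapping itself comes from the list; the `rename_node_keys` helper and the sample item are hypothetical and not part of PrimAITE.

```python
# Old camelCase keys mapped to the new snake_case keys, as listed above.
LEGACY_TO_NEW_KEYS = {
    "itemType": "item_type",
    "id": "node_id",
    "portsList": "ports_list",
    "serviceList": "service_list",
    "baseType": "base_type",
    "nodeType": "node_type",
    "hardwareState": "hardware_state",
    "ipAddress": "ip_address",
    "softwareState": "software_state",
    "fileSystemState": "file_system_state",
}


def rename_node_keys(item):
    """Return a copy of a legacy config item with NODE keys renamed."""
    if item.get("itemType") != "NODE":
        # Only NODE items are covered by this mapping; copy others unchanged.
        return dict(item)
    return {LEGACY_TO_NEW_KEYS.get(key, key): value for key, value in item.items()}


# Made-up legacy item for illustration.
legacy_item = {"itemType": "NODE", "id": "1", "ipAddress": "192.168.1.2", "priority": "P1"}
renamed_item = rename_node_keys(legacy_item)
```

Keys that are not in the mapping (such as `priority` here) pass through unchanged.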

**The following changes are made in the primaite/environment/primaite_env.py module:
In the create_node function, the id of the node needs to be retrieved using the new "node_id" key.**
- _id -> node_id
- _name -> name
- _type -> node_type
- _priority -> priority
- _state -> hardware_state
- _ip_address -> ip_address
- _os_state -> software_state
- _file_system_state -> file_system_state
- _config_values -> config_values

**Few other cosmetic/code style changes too:**
- Enum classes renamed to use CamelCase.
- Started refactoring out unnecessary getters and setters by using `@property` and `@<property name>.setter`.
- Have started to add Type Hints.
- Have started to move docstrings over to the Sphinx ReStructured text format.
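
For readers unfamiliar with the `@property` pattern mentioned above, here is a minimal sketch. The trimmed-down `Node` class, the integer `priority`, and the validation rule are assumptions made for illustration; the real class uses enums and has more attributes.

```python
class Node:
    """Hypothetical, trimmed-down node used only for illustration."""

    def __init__(self, node_id: str, priority: int) -> None:
        self.node_id: str = node_id
        self._priority: int = priority

    @property
    def priority(self) -> int:
        """Read access replaces an explicit get_priority() method."""
        return self._priority

    @priority.setter
    def priority(self, value: int) -> None:
        """Write access replaces set_priority() and can still validate."""
        if value < 0:
            raise ValueError("priority must be non-negative")
        self._priority = value


node = Node(node_id="1", priority=3)
node.priority = 5  # goes through the setter
```

Callers use plain attribute syntax, while validation stays in one place.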

Related work items: #1355
2023-05-26 09:01:21 +00:00
Chris McCarthy
05ebd15053 #1355 - Renamed the NodeType custom type in custom_typing.py as it clashed with the NodeType enum in enums.py 2023-05-26 09:43:37 +01:00
Chris McCarthy
6245ad9298 #1355 - Carried out full renaming in node.py, active_node.py, passive_node.py, and service_node.py to make params and variable names explicit.
- Made the same renaming in the yaml laydown config files.
- Added Type hints wherever I've been.
- Added a custom NodeType in custom_typing.py to encompass the Union of ActiveNode, PassiveNode, ServiceNode.
2023-05-25 21:03:11 +01:00
Christopher McCarthy
3ac2399115 Merged PR 56: #902 - Fix the reward functionality for node operating system state
#902 - replaced 'final_node_<placeholder>' with 'reference_node_<placeholder>' in methods for scoring of os_state, file_system_state, service state and operating state. This fixed the reward function so it is checked at each step for node operating system state, operating state, file system state and service state.
- Added unit tests.

Related work items: #902
2023-05-25 15:28:19 +00:00
Christopher McCarthy
182bf177a3 Merged PR 59: v1.2.1
v1.2.1
2023-05-25 14:23:55 +00:00
Christopher McCarthy
d3aa69757b Merged PR 58: v1.2.1
v1.2.1
2023-05-25 14:21:16 +00:00
Chris McCarthy
56bce1431b v1.2.1 2023-05-25 15:20:19 +01:00
SunilSamra
fa0e836f65 902 - changed test comment to explain the outcome of the average reward 2023-05-25 14:36:26 +01:00
Christopher McCarthy
e2cc1cb28a Merged PR 57: Resync dev with v1.2.0
Resync dev with v1.2.0
2023-05-25 13:16:29 +00:00
Christopher McCarthy
1d0fd04393 Merged PR 55: Release v1.2.0
Updated artifact-release-pipeline.yaml pipeline to build for Python 3.8 to 3.10 and MacOS, Windows, and Linux.
2023-05-25 13:11:20 +00:00
Chris McCarthy
ddb6adae2b #902 - replaced 'final_node_<placeholder>' with 'reference_node_<placeholder>' in methods for scoring of os_state, file_system_state, service state and operating state. This fixed the reward function so it is checked at each step for node operating system state, operating state, file system state and service state.
- Added unit tests.
2023-05-25 14:05:53 +01:00
Christopher McCarthy
4cabc8a87a Merged PR 54: #1356 - added if statements to set class methods for file system state, os st...
#1356 - added if statements to set class methods for file system state, os state and service states. Refactored file enums.py
- Added unit tests

Related work items: #1356
2023-05-25 12:33:16 +00:00
Chris McCarthy
04c27cc7d5 Updated artifact-release-pipeline.yaml pipeline to build for Python 3.8 to 3.10 and MacOS, Windows, and Linux. 2023-05-25 13:24:49 +01:00
Christopher McCarthy
057fb44061 Merged PR 53: release v1.2.0 2023-05-25 12:03:54 +00:00
Chris McCarthy
51c72aa5be #1356 - added if statements to set class methods for file system state, os state and service states. Refactored file enums.py
- Added unit tests
2023-05-25 13:02:15 +01:00
Christopher McCarthy
769256f0a5 Merged PR 52: #1378 - Ordering of actions in step
#1378 - Re-arranged the action step function in the following order:
1. Implement the Blue Action
2. Perform any time-based activities
3. Apply PoL
4. Implement Red Action
5. Calculate reward signal
6. Output Verbose (currently disabled)
7. Update env_obs
8. Add transaction to the list of transactions
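
The ordering above can be sketched with a toy class that only records which phase runs when. All names here (`StepOrderSketch`, the phase labels, the placeholder reward and observation) are hypothetical; the real step function does actual work in each phase.

```python
class StepOrderSketch:
    """Toy environment that records the order of the step phases."""

    def __init__(self):
        self.phase_log = []
        self.transactions = []

    def step(self, blue_action):
        self.phase_log.append("blue_action")     # 1. Implement the Blue Action
        self.phase_log.append("time_based")      # 2. Perform any time-based activities
        self.phase_log.append("pol")             # 3. Apply PoL
        self.phase_log.append("red_action")      # 4. Implement Red Action
        reward = 0.0                             # 5. Calculate reward signal (placeholder)
        self.phase_log.append("reward")
        # 6. Verbose output is skipped, since it is currently disabled.
        observation = [0]                        # 7. Update env_obs (placeholder)
        self.phase_log.append("update_obs")
        self.transactions.append((blue_action, reward))  # 8. Record the transaction
        self.phase_log.append("transaction")
        return observation, reward


env = StepOrderSketch()
observation, reward = env.step(blue_action=0)
```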

Related work items: #1378
2023-05-25 11:42:19 +00:00
Chris McCarthy
7bbdbd6997 #1378 - Re-added post blue and snapshots 2023-05-25 12:37:42 +01:00
Chris McCarthy
95a0669e5c #1378 - Re-arranged the action step function in the following order:
1. Implement the Blue Action
2. Perform any time-based activities
3. Apply PoL
4. Implement Red Action
5. Calculate reward signal
6. Output Verbose (currently disabled)
7. Update env_obs
8. Add transaction to the list of transactions
2023-05-25 11:58:54 +01:00
Chris McCarthy
71f33ed44e Ran pre-commit hook on all files and performed changes to fix flake8 failures 2023-05-25 11:42:19 +01:00
Chris McCarthy
18f89faf03 Package restructuring and renaming for 1.2.0 2023-05-25 10:52:29 +01:00
Chris McCarthy
9bd7aade43 Package restructuring 2023-05-25 10:31:37 +01:00
Christopher McCarthy
754b16c8c8 Merged PR 7: v1.1.0
v1.1.0
2023-04-06 10:07:40 +00:00
Chris McCarthy
e473c710a2 Bumped version number to 1.1.0 in setup.py 2023-04-06 11:06:55 +01:00
Chris McCarthy
39da5bbe01 Committed the v1.1.0 code provided by James Short. Had to add setuptools==66 to setup.py as older versions of Gym are uninstallable with setuptools>=67 2023-04-06 11:04:09 +01:00
Christopher McCarthy
027709d1e8 Merged PR 2: Added Python version to release pipeline
Added Python version to release pipeline
2023-03-28 16:39:10 +00:00
Chris McCarthy
959b43743c Added Python version to release pipeline 2023-03-28 17:38:30 +01:00
Christopher McCarthy
43a2b1fa3c Merged PR 1: v1.0.0
Initial commit of v1.0.0. Updated the .gitignore for the standard Python gitignore. Added Azure DevOps release pipeline for proper artifact release from the start.
2023-03-28 16:36:07 +00:00
Chris McCarthy
8fc0316253 Initial commit of v1.0.0. Updated the .gitignore for the standard Python gitignore. Added Azure DevOps release pipeline for proper artifact release from the start. 2023-03-28 17:33:34 +01:00
149 changed files with 25818 additions and 20 deletions

13
.flake8 Normal file

@@ -0,0 +1,13 @@
[flake8]
max-line-length=120
extend-ignore =
    D105
    D107
    D100
    D104
    E203
    E712
    D401
    F811
exclude =
    docs/source/*

41
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file

@@ -0,0 +1,41 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG] - <bug title goes here>"
labels: bug
assignees: ''
---
### Describe the bug:
A clear and concise description of what the bug is.
### To Reproduce:
Steps to reproduce the behaviour:
1. Import '...'
2. Instantiate '....'
3. Pass to '....'
4. Run '....'
5. See error
### Expected behaviour
A clear and concise description of what you expected to happen.
### Screenshots/Outputs
If applicable, add screenshots to help explain your problem.
### Environment (please complete the following information)
- **OS:** [e.g. Ubuntu 22.04]
- **Python:** [e.g. 3.10.11]
- **PrimAITE Version:** [e.g. v2.0.0]
- **Software:** [e.g. cli, Jupyter, PyCharm, VSCode etc.]
### Additional context
Add any other context about the problem here.


@@ -0,0 +1,24 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[REQUEST] - <request title goes here>"
labels: feature_request
assignees: ''
---
### Is your feature request related to a problem?
If so, please give a concise description of what the problem is. Ex. I'm always frustrated when [...]
### Describe the solution you'd like:
A clear and concise description of what you want to happen.
### Describe alternatives you've considered:
A clear and concise description of any alternative solutions or features you've considered.
### Additional context:
Add any other context or screenshots about the feature request here.

60
.github/workflows/build-sphinx.yml vendored Normal file

@@ -0,0 +1,60 @@
name: build-sphinx-to-github-pages
env:
  GITHUB_ACTOR: Autonomous-Resilient-Cyber-Defence
  GITHUB_REPOSITORY: Autonomous-Resilient-Cyber-Defence/PrimAITE
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN}}
on:
  push:
    branches: [main]
jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install python dev
        run: |
          set -x
          sudo apt-get update
          sudo add-apt-repository ppa:deadsnakes/ppa -y
          sudo apt install python${{ matrix.python-version}}-dev -y
      - name: Install Git
        run: |
          set -x
          sudo apt-get install -y git
        shell: bash
      - name: Set pip, wheel, setuptools versions
        run: |
          python -m pip install --upgrade pip==23.0.1
          pip install wheel==0.38.4 --upgrade
          pip install setuptools==66 --upgrade
          pip install build
      - name: Install PrimAITE for docs autosummary
        run: |
          set -x
          python -m pip install -e .[dev]
      - name: Run build script for Sphinx pages
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          set -x
          bash $PWD/docs/build-sphinx-docs-to-github-pages.sh

66
.github/workflows/python-package.yml vendored Normal file

@@ -0,0 +1,66 @@
name: Python package
on:
  push:
    branches:
      - main
      - dev
      - dev-gui
      - 'release/**'
  pull_request:
    branches:
      - main
      - dev
      - dev-gui
      - 'release/**'
jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install python dev
        run: |
          sudo apt update
          sudo add-apt-repository ppa:deadsnakes/ppa -y
          sudo apt install python${{ matrix.python-version}}-dev -y
      - name: Install Build Dependencies
        run: |
          python -m pip install --upgrade pip==23.0.1
          pip install wheel==0.38.4 --upgrade
          pip install setuptools==66 --upgrade
          pip install build
      - name: Build PrimAITE
        run: |
          python -m build
      - name: Install PrimAITE
        run: |
          PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
          python -m pip install $PRIMAITE_WHEEL[dev]
      - name: Perform PrimAITE Setup
        run: |
          primaite setup
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings.
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics
      - name: Run tests
        run: |
          pytest tests/

152
.gitignore vendored Normal file

@@ -0,0 +1,152 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
tests/assets/**/*.png
tests/assets/**/tensorboard_logs/
tests/assets/**/checkpoints/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/source/_autosummary
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# IDE
.idea/
docs/source/primaite-dependencies.rst
# outputs
src/primaite/outputs/
# benchmark session outputs
benchmark/output

29
.pre-commit-config.yaml Normal file

@@ -0,0 +1,29 @@
repos:
  - repo: http://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: check-added-large-files
        args: ['--maxkb=1000']
      - id: mixed-line-ending
      - id: requirements-txt-fixer
  - repo: http://github.com/psf/black
    rev: 23.1.0
    hooks:
      - id: black
        args: [ "--line-length=120" ]
        additional_dependencies:
          - jupyter
  - repo: http://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: [ "--profile", "black" ]
  - repo: http://github.com/PyCQA/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        additional_dependencies:
          - flake8-docstrings

90
CHANGELOG.md Normal file

@@ -0,0 +1,90 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [2.0.0] - 2023-07-26
### Added
- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE.
- Application Directories to enable PrimAITE as a Python package with predefined directories for storage.
- Support for Ray RLlib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib.
- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`.
- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options.
- Session loading to revisit previously run sessions for SB3 Agents.
- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface.
- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs:
    1. Session Metadata
    2. Results
    3. Diagrams
    4. Saved agents (training checkpoints and a final trained agent).
- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup.
- Benchmarking of PrimAITE performance, showcasing session and step durations for reference.
- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`.
### Changed
- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions.
- Improved `Node` attribute naming convention for consistency, now adhering to `snake_case`.
- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names.
- Docs and Tests now sit outside the `src/` directory.
- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages.
- All dependencies are now defined in the `pyproject.toml` file.
- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each.
- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management.
- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations.
- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority.
### Fixed
- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation.
- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes.
- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands.
- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making.
## [1.1.1] - 2023-06-27
### Bug Fixes
* Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by:
    * Implementing a reference copy of all green IERs (`self.green_iers_reference`).
    * Clearing the traffic on reference IERs at the same time as the live IERs.
    * Passing the `green_iers_reference` to the `apply_iers` function at the reference stage.
    * Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`.
    * Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment.
    * Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes.
    * Removing the unnecessary "Reapply PoL and IERs" action from the step function.
    * Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function.
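
The rule for the `green_ier_blocked` reward described in the fix above can be sketched as follows. The function name, its boolean parameters, and the example value are hypothetical; only the condition itself is taken from the changelog entry.

```python
def green_ier_reward(running_in_reference, running_in_live, green_ier_blocked):
    """Return the reward contribution for one green IER.

    The penalty applies only when the IER runs in the reference
    environment but is blocked in the live environment.
    """
    if running_in_reference and not running_in_live:
        return green_ier_blocked
    # Blocked in both, or running in live: no penalty in this sketch.
    return 0.0


# Hypothetical penalty value for illustration.
blocked_only_in_live = green_ier_reward(True, False, green_ier_blocked=-10.0)
blocked_in_both = green_ier_reward(False, False, green_ier_blocked=-10.0)
```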
## [1.1.0] - 2023-03-13
### Added
* The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function.
* The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name.
* Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components:
    * Blue agent observation space;
    * Blue agent action space;
    * Reward function;
    * Node pattern-of-life.
* The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network).
* New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile.
* NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability.
* The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value.
* The User Guide has been updated to reflect all the above changes.
### Changed
* "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing.
* "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
* "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
* "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file.
* Updates to Transactions.
### Fixed
* Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default.
[unreleased]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/compare/v2.0.0...HEAD
[2.0.0]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/tag/v2.0.0

39
CONTRIBUTING.md Normal file

@@ -0,0 +1,39 @@
# How to contribute to PrimAITE?
### **Did you find a bug?**
* **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues).
* If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D+-+%3Cbug+title+goes+here%3E). Be sure to follow our bug report template with the headers **Describe the bug**, **To Reproduce**, **Expected behaviour**, **Screenshots/Outputs**, **Environment**, and **Additional context**.
### **Do you have a solution to fix the bug?**
* [Fork the repository](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/fork).
* Install the pre-commit hook with `pre-commit install`.
* Implement the bug fix.
* Update documentation where applicable.
* Update the **UNRELEASED** section of the [CHANGELOG.md](CHANGELOG.md) file.
* Write a suitable test/tests.
* Commit the bug fix to the dev branch on your fork. If the bug has an open issue under [Issues](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues), reference the issue in the commit message (e.g. #1 references issue 1).
* Submit a pull request from your dev branch to the Autonomous-Resilient-Cyber-Defence/PrimAITE dev branch. Again, if the bug has an open issue under [Issues](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues), reference the issue in the pull request description.
### **Did you fix whitespace, format code, or make a purely cosmetic patch?**
Changes that are cosmetic in nature and do not add anything substantial to the stability, functionality, or testability of PrimAITE will generally not be accepted.
### **Do you intend to add a new feature or change an existing one?**
* Submit a [feature request issue](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues/new?assignees=&labels=feature_request&projects=&template=feature_request.md&title=%5BREQUEST%5D+-+%3Crequest+title+goes+here%3E).
* Know how to implement the new feature or change? Follow the same steps in the bug fix section above to fork, build, document, test, commit, and submit a pull request.
### **Do you have questions about the source code?**
Ask any question about how to use PrimAITE in our discussions section.
### **Do you want to contribute to the PrimAITE documentation?**
Please follow the "Do you intend to add a new feature or change an existing one?" section above and tag your feature request issue and pull request with the documentation tag.
Thank you from the PrimAITE dev team! 🙌

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 - 2025 Defence Science and Technology Laboratory UK (https://dstl.gov.uk)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

2
MANIFEST.in Normal file
View File

@@ -0,0 +1,2 @@
include src/primaite/setup/_package_data/primaite_config.yaml
include src/primaite/config/_package_data/*.yaml

Binary file not shown (size: 274 KiB).

166
README.md
View File

@@ -1,20 +1,146 @@
# Introduction
TODO: Give a short introduction of your project. Let this section explain the objectives or the motivation behind this project.
# Getting Started
TODO: Guide users through getting your code up and running on their own system. In this section you can talk about:
1. Installation process
2. Software dependencies
3. Latest releases
4. API references
# Build and Test
TODO: Describe and show how to build your code and run the tests.
# Contribute
TODO: Explain how other users and developers can contribute to make your code better.
If you want to learn more about creating good readme files then refer the following [guidelines](https://docs.microsoft.com/en-us/azure/devops/repos/git/create-a-readme?view=azure-devops). You can also seek inspiration from the below readme files:
- [ASP.NET Core](https://github.com/aspnet/Home)
- [Visual Studio Code](https://github.com/Microsoft/vscode)
- [Chakra Core](https://github.com/Microsoft/ChakraCore)
# PrimAITE
![image](./PrimAITE_logo_transparent.png)
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for the purposes of training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment, which includes:
- The ability to model a relevant platform / system context;
- The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, traffic loading, operating systems and services;
- The ability to operate at machine speed to enable fast training cycles.
PrimAITE presents the following features:
- Highly configurable (via YAML files) to provide the means to model a variety of platform / system laydowns, mission profiles and adversarial attack scenarios;
- A Reinforcement Learning (RL) reward function based on (a) the ability to counter the specific modelled adversarial cyber-attack, and (b) the ability to ensure mission success;
- Provision of logging to support AI evaluation and metrics gathering;
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life, adversarial behaviour and mission data (on a sliding scale of criticality);
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
- Application of IERs to the platform / system laydown adheres to the ACL ruleset;
- Presents an OpenAI Gym or RLlib interface to the environment, allowing integration with any compliant defensive agents;
- Full capture of discrete logs relating to agent training (full system state, agent actions taken, instantaneous and average reward for every step of every episode);
- NetworkX provides laydown visualisation capability.
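The ACL behaviour described in the feature list above can be illustrated with a minimal Python sketch. This is not PrimAITE's implementation: the `SketchRule` class, the `"ANY"` wildcard value, the first-match-wins ordering and the implicit `DENY` default are all assumptions made for illustration only.

```python
from dataclasses import dataclass
from typing import Iterable

ANY = "ANY"  # Hypothetical wildcard value, assumed for this sketch


@dataclass(frozen=True)
class SketchRule:
    """Illustrative ACL rule: permission, source IP, destination IP, protocol and port."""

    permission: str  # "ALLOW" or "DENY"
    source_ip: str
    dest_ip: str
    protocol: str
    port: str

    def matches(self, source_ip: str, dest_ip: str, protocol: str, port: str) -> bool:
        """Return True if every field is either the wildcard or an exact match."""
        pairs = (
            (self.source_ip, source_ip),
            (self.dest_ip, dest_ip),
            (self.protocol, protocol),
            (self.port, port),
        )
        return all(rule_value in (ANY, value) for rule_value, value in pairs)


def is_blocked(
    rules: Iterable[SketchRule],
    source_ip: str,
    dest_ip: str,
    protocol: str,
    port: str,
    implicit_permission: str = "DENY",
) -> bool:
    """Apply the first matching rule (assumed ordering); fall back to the implicit permission."""
    for rule in rules:
        if rule.matches(source_ip, dest_ip, protocol, port):
            return rule.permission == "DENY"
    return implicit_permission == "DENY"


rules = [
    SketchRule("ALLOW", "192.168.1.2", "192.168.1.3", "TCP", "80"),
    SketchRule("DENY", ANY, ANY, ANY, ANY),
]
print(is_blocked(rules, "192.168.1.2", "192.168.1.3", "TCP", "80"))  # False
print(is_blocked(rules, "192.168.1.9", "192.168.1.3", "TCP", "80"))  # True
```

In this sketch an IER would only be applied to the laydown if `is_blocked` returns `False` for its source, destination, protocol and port.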
## Getting Started with PrimAITE
### 💫 Install & Run
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and macOS.
Currently, the PrimAITE wheel can only be installed from GitHub. This may change in the future with release to PyPI.
#### Windows (PowerShell)
**Prerequisites:**
* Manual install of Python >= 3.8 < 3.11
**Install:**
``` powershell
mkdir ~\primaite
cd ~\primaite
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
.\.venv\Scripts\activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
primaite setup
```
**Run:**
``` bash
primaite session
```
#### Unix
**Prerequisites:**
* Manual install of Python >= 3.8 < 3.11
``` bash
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt install python3.10
sudo apt-get install python3-pip
sudo apt-get install python3-venv
```
**Install:**
``` bash
mkdir ~/primaite
cd ~/primaite
python3 -m venv .venv
source .venv/bin/activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
primaite setup
```
**Run:**
``` bash
primaite session
```
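Both install routes above list Python >= 3.8 and < 3.11 as a prerequisite. As a minimal sketch (the helper name `python_version_supported` is hypothetical and not part of PrimAITE), the interpreter inside the venv can be checked before installing the wheel:

```python
import sys
from typing import Sequence


def python_version_supported(version_info: Sequence[int] = sys.version_info) -> bool:
    """Return True if the (major, minor) version is >= 3.8 and < 3.11."""
    return (3, 8) <= tuple(version_info[:2]) < (3, 11)


if not python_version_supported():
    print(f"Python {sys.version_info.major}.{sys.version_info.minor} is outside the supported range.")
```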
### Developer Install from Source
To make your own changes to PrimAITE, perform the install from source (developer install).
#### 1. Clone the PrimAITE repository
``` unix
git clone git@github.com:Autonomous-Resilient-Cyber-Defence/PrimAITE.git
```
#### 2. CD into the repo directory
``` unix
cd PrimAITE
```
#### 3. Create a new python virtual environment (venv)
```unix
python3 -m venv venv
```
#### 4. Activate the venv
##### Unix
```bash
source venv/bin/activate
```
##### Windows (PowerShell)
```powershell
.\venv\Scripts\activate
```
#### 5. Install `primaite` with the dev extra into the venv along with all of its dependencies
```bash
python3 -m pip install -e .[dev]
```
#### 6. Perform the PrimAITE setup:
```bash
primaite setup
```
## 📚 Building documentation
The PrimAITE documentation can be built with the following commands:
##### Unix
```bash
cd docs
make html
```
##### Windows (PowerShell)
```powershell
cd docs
.\make.bat html
```


@@ -0,0 +1,164 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (TensorFlow).
# Options are:
# "TF" (TensorFlow)
# "TF2" (TensorFlow 2.X)
# "TORCH" (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER are randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets how the action space is defined.
# Options are:
# "NODE"
# "ACL"
# "ANY" (Both NODE and ACL actions)
action_type: NODE
# Observation space
observation_space:
  flatten: true
  components:
    - name: NODE_LINK_TABLE
    - name: NODE_STATUSES
    - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 500
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system


@@ -0,0 +1,449 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
import platform
import shutil
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Optional, Tuple, Union
from unittest.mock import patch
import GPUtil
import plotly.graph_objects as go
import polars as pl
import psutil
import yaml
from plotly.graph_objs import Figure
from pylatex import Command, Document
from pylatex import Figure as LatexFigure
from pylatex import Section, Subsection, Tabular
from pylatex.utils import bold
import primaite
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.data_viz.session_plots import get_plotly_config
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
_LOGGER = primaite.getLogger(__name__)
_BENCHMARK_ROOT = Path(__file__).parent
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
# Clear and recreate the output directory
if _OUTPUT_ROOT.exists():
    shutil.rmtree(_OUTPUT_ROOT)
_OUTPUT_ROOT.mkdir()
_TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.yaml"
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
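def get_size(size_bytes: int) -> str:
    """
    Scale bytes to its proper format.

    e.g:
        1253656 => '1.20MB'
        1253656678 => '1.17GB'

    :param size_bytes: The size in bytes.
    :return: The size as a human-readable string.
    """
    factor = 1024
    for unit in ["", "K", "M", "G", "T", "P"]:
        if size_bytes < factor:
            return f"{size_bytes:.2f}{unit}B"
        size_bytes /= factor
    # Sizes of 1024 PB or more fall through the loop; report them in exabytes
    # rather than implicitly returning None.
    return f"{size_bytes:.2f}EB"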
def _get_system_info() -> Dict:
    """Builds and returns a dict containing system info."""
    uname = platform.uname()
    cpu_freq = psutil.cpu_freq()
    virtual_mem = psutil.virtual_memory()
    swap_mem = psutil.swap_memory()
    gpus = GPUtil.getGPUs()
    return {
        "System": {
            "OS": uname.system,
            "OS Version": uname.version,
            "Machine": uname.machine,
            "Processor": uname.processor,
        },
        "CPU": {
            "Physical Cores": psutil.cpu_count(logical=False),
            "Total Cores": psutil.cpu_count(logical=True),
            "Max Frequency": f"{cpu_freq.max:.2f}Mhz",
        },
        "Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
        "GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
    }
def _build_benchmark_latex_report(
    benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
):
    geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
    data = benchmark_metadata_dict
    primaite_version = data["primaite_version"]
    # Create a new document
    doc = Document("report", geometry_options=geometry_options)
    # Title
    doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
    doc.preamble.append(Command("author", "PrimAITE Dev Team"))
    doc.preamble.append(Command("date", datetime.now().date()))
    doc.append(Command("maketitle"))
    sessions = data["total_sessions"]
    episodes = data["training_config"]["num_train_episodes"]
    steps = data["training_config"]["num_train_steps"]
    # Body
    with doc.create(Section("Introduction")):
        doc.append(
            f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
            f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
        )
        doc.append(
            f"\nThe benchmarking process consists of running {sessions} training sessions using the same "
            f"training and lay down config files. Each session trains an agent for {episodes} episodes, "
            f"with each episode consisting of {steps} steps."
        )
        doc.append(
            f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
            f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
            f"Finally, a 25-window rolling average of the combined average reward per episode is calculated for "
            f"further smoothing."
        )
    with doc.create(Section("System Information")):
        with doc.create(Subsection("Python")):
            with doc.create(Tabular("|l|l|")) as table:
                table.add_hline()
                table.add_row((bold("Version"), sys.version))
                table.add_hline()
        for section, section_data in data["system_info"].items():
            if section_data:
                with doc.create(Subsection(section)):
                    if isinstance(section_data, dict):
                        # Key/value sections (System, CPU, Memory) become two-column tables
                        with doc.create(Tabular("|l|l|")) as table:
                            table.add_hline()
                            for key, value in section_data.items():
                                table.add_row((bold(key), value))
                                table.add_hline()
                    elif isinstance(section_data, list):
                        # List sections (GPU) become one column per key, one row per item
                        headers = section_data[0].keys()
                        tabs_str = "|".join(["l" for _ in range(len(headers))])
                        tabs_str = f"|{tabs_str}|"
                        with doc.create(Tabular(tabs_str)) as table:
                            table.add_hline()
                            table.add_row([bold(h) for h in headers])
                            table.add_hline()
                            for item in section_data:
                                table.add_row(item.values())
                                table.add_hline()
    headers_map = {
        "total_sessions": "Total Sessions",
        "total_episodes": "Total Episodes",
        "total_time_steps": "Total Steps",
        "av_s_per_session": "Av Session Duration (s)",
        "av_s_per_step": "Av Step Duration (s)",
        "av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
    }
    with doc.create(Section("Stats")):
        with doc.create(Subsection("Benchmark Results")):
            with doc.create(Tabular("|l|l|")) as table:
                table.add_hline()
                for section, header in headers_map.items():
                    if section.startswith("av_"):
                        table.add_row((bold(header), f"{data[section]:.4f}"))
                    else:
                        table.add_row((bold(header), data[section]))
                    table.add_hline()
    with doc.create(Section("Graphs")):
        with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
            with doc.create(LatexFigure(position="h!")) as pic:
                pic.add_image(str(this_version_plot_path))
                pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
        with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
            with doc.create(LatexFigure(position="h!")) as pic:
                pic.add_image(str(all_version_plot_path))
                pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
    doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
class BenchmarkPrimaiteSession(PrimaiteSession):
    """A benchmarking primaite session."""

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
    ):
        super().__init__(training_config_path, lay_down_config_path)
        self.setup()

    @property
    def env(self) -> Primaite:
        """Direct access to the env for ease of testing."""
        return self._agent_session._env  # noqa

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        shutil.rmtree(self.session_path)
        _LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")

    def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
        """
        Calculate and return the learning benchmark durations.

        Calculates the:
        - Total learning time in seconds
        - Total learning time per time step in seconds
        - Total learning time per 100 time steps per 10 nodes in seconds

        :return: The learning benchmark durations as a Tuple of three floats:
            Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
        """
        data = self.metadata_file_as_dict()
        start_dt = datetime.fromisoformat(data["start_datetime"])
        end_dt = datetime.fromisoformat(data["end_datetime"])
        delta = end_dt - start_dt
        total_s = delta.total_seconds()
        total_steps = data["learning"]["total_time_steps"]
        s_per_step = total_s / total_steps
        num_nodes = self.env.num_nodes
        num_intervals = total_steps / 100
        av_interval_time = total_s / num_intervals
        s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)
        return total_s, s_per_step, s_per_100_steps_10_nodes

    def learn_metadata_dict(self) -> Dict[str, Any]:
        """Metadata specific to the learning session."""
        total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
        return {
            "total_episodes": self.env.actual_episode_count,
            "total_time_steps": self.env.total_step_count,
            "total_s": total_s,
            "s_per_step": s_per_step,
            "s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
            "av_reward_per_episode": self.learn_av_reward_per_episode_dict(),
        }
def _get_benchmark_session_path(session_timestamp: datetime) -> Path:
    return _OUTPUT_ROOT / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
def _get_benchmark_primaite_session() -> BenchmarkPrimaiteSession:
    with patch("primaite.agents.agent_abc.get_session_path", _get_benchmark_session_path) as mck:
        mck.session_timestamp = datetime.now()
        return BenchmarkPrimaiteSession(_TRAINING_CONFIG_PATH, _LAY_DOWN_CONFIG_PATH)
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) -> dict:
    n = len(metadata_dict)
    with open(_TRAINING_CONFIG_PATH, "r") as file:
        training_config_dict = yaml.safe_load(file)
    with open(_LAY_DOWN_CONFIG_PATH, "r") as file:
        lay_down_config_dict = yaml.safe_load(file)
    averaged_data = {
        "start_timestamp": start_datetime.isoformat(),
        "end_datetime": datetime.now().isoformat(),
        "primaite_version": primaite.__version__,
        "system_info": _get_system_info(),
        "total_sessions": n,
        "total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
        "total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
        "av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / n,
        "av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / n,
        "av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values()) / n,
        "combined_av_reward_per_episode": {},
        "session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
        "training_config": training_config_dict,
        "lay_down_config": lay_down_config_dict,
    }
    # Average each episode's reward across all sessions
    episodes = metadata_dict[1]["av_reward_per_episode"].keys()
    for episode in episodes:
        combined_av_reward = sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / n
        averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
    return averaged_data
def _get_df_from_episode_av_reward_dict(data: Dict):
    data: Dict = {"episode": data.keys(), "av_reward": data.values()}
    return (
        pl.from_dict(data)
        .with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
        .rename({"rolling_mean": "rolling_av_reward"})
    )
def _plot_benchmark_metadata(
    benchmark_metadata_dict: Dict,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
) -> Figure:
    if title:
        if subtitle:
            # Open the <sup> tag so that it pairs with the closing </sup>
            title = f"{title} <br><sup>{subtitle}</sup>"
    else:
        if subtitle:
            title = subtitle
    config = get_plotly_config()
    layout = go.Layout(
        autosize=config["size"]["auto_size"],
        width=config["size"]["width"],
        height=config["size"]["height"],
    )
    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=config["template"])
    # One faint grey trace per individual session
    for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
        df = _get_df_from_episode_av_reward_dict(av_reward_dict)
        fig.add_trace(
            go.Scatter(
                x=df["episode"],
                y=df["av_reward"],
                mode="lines",
                name=f"Session {session}",
                opacity=0.25,
                line={"color": "#a6a6a6"},
            )
        )
    # Combined session average and its rolling average
    df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
    fig.add_trace(
        go.Scatter(
            x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
        )
    )
    fig.add_trace(
        go.Scatter(
            x=df["episode"],
            y=df["rolling_av_reward"],
            mode="lines",
            name="Rolling Av (Combined Session Av)",
            line={"color": "#4CBB17"},
        )
    )
    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
        },
        yaxis={"title": "Average Reward"},
        title=title,
    )
    return fig
def _plot_all_benchmarks_combined_session_av():
    """
    Plot the Benchmark results for each released version of PrimAITE.

    Does this by iterating over the ``benchmark/results`` directory and
    extracting the benchmark metadata json for each version that has been
    benchmarked. The combined_av_reward_per_episode is extracted from each,
    converted into a polars dataframe, and plotted as a scatter line in plotly.
    """
    title = "PrimAITE Versions Learning Benchmark"
    subtitle = "Rolling Av (Combined Session Av)"
    if title:
        if subtitle:
            # Open the <sup> tag so that it pairs with the closing </sup>
            title = f"{title} <br><sup>{subtitle}</sup>"
    else:
        if subtitle:
            title = subtitle
    config = get_plotly_config()
    layout = go.Layout(
        autosize=config["size"]["auto_size"],
        width=config["size"]["width"],
        height=config["size"]["height"],
    )
    # Create the line graph with a colored line
    fig = go.Figure(layout=layout)
    fig.update_layout(template=config["template"])
    for dir in _RESULTS_ROOT.iterdir():
        if dir.is_dir():
            metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
            with open(metadata_file, "r") as file:
                metadata_dict = json.load(file)
            df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
            fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
    # Set the layout of the graph
    fig.update_layout(
        xaxis={
            "title": "Episode",
            "type": "linear",
        },
        yaxis={"title": "Average Reward"},
        title=title,
    )
    fig["data"][0]["showlegend"] = True
    return fig
def run():
    """Run the PrimAITE benchmark."""
    start_datetime = datetime.now()
    av_reward_per_episode_dicts = {}
    for i in range(1, 11):
        print(f"Starting Benchmark Session: {i}")
        with _get_benchmark_primaite_session() as session:
            session.learn()
            av_reward_per_episode_dicts[i] = session.learn_metadata_dict()
    benchmark_metadata = _build_benchmark_results_dict(
        start_datetime=start_datetime, metadata_dict=av_reward_per_episode_dicts
    )
    v_str = f"v{primaite.__version__}"
    version_result_dir = _RESULTS_ROOT / v_str
    if version_result_dir.exists():
        shutil.rmtree(version_result_dir)
    version_result_dir.mkdir(exist_ok=True, parents=True)
    with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
        json.dump(benchmark_metadata, file, indent=4)
    title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
    fig = _plot_benchmark_metadata(benchmark_metadata, title=title)
    this_version_plot_path = version_result_dir / f"{title}.png"
    fig.write_image(this_version_plot_path)
    fig = _plot_all_benchmarks_combined_session_av()
    all_version_plot_path = _RESULTS_ROOT / "PrimAITE Versions Learning Benchmark.png"
    fig.write_image(all_version_plot_path)
    _build_benchmark_latex_report(benchmark_metadata, this_version_plot_path, all_version_plot_path)
if __name__ == "__main__":
    run()

Binary file not shown (size: 79 KiB).

Binary file not shown (size: 225 KiB).

File diff suppressed because it is too large.

521
diagram/classes.puml Normal file
View File

@@ -0,0 +1,521 @@
@startuml classes
set namespaceSeparator none
class "ACLRule" as primaite.acl.acl_rule.ACLRule {
dest_ip : str
permission
port : str
protocol : str
source_ip : str
get_dest_ip() -> str
get_permission() -> str
get_port() -> str
get_protocol() -> str
get_source_ip() -> str
}
class "AbstractObservationComponent" as primaite.environment.observations.AbstractObservationComponent {
current_observation : NotImplementedType, ndarray
env : str
space : Space
structure : List[str]
{abstract}generate_structure() -> List[str]
{abstract}update() -> None
}
class "AccessControlList" as primaite.acl.access_control_list.AccessControlList {
acl
acl_implicit_permission
acl_implicit_rule
max_acl_rules : int
add_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: str) -> None
check_address_match(_rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool
get_dictionary_hash(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int
get_relevant_rules(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> Dict[int, ACLRule]
is_blocked(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool
remove_all_rules() -> None
remove_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None
}
class "AccessControlList_" as primaite.environment.observations.AccessControlList_ {
current_observation : ndarray
space : MultiDiscrete
structure : list
generate_structure() -> List[str]
update() -> None
}
class "ActiveNode" as primaite.nodes.active_node.ActiveNode {
file_system_action_count : int
file_system_scanning : bool
file_system_scanning_count : int
file_system_state_actual : GOOD
file_system_state_observed : REPAIRING, RESTORING, GOOD
ip_address : str
patching_count : int
software_state
software_state : GOOD
set_file_system_state(file_system_state: FileSystemState) -> None
set_file_system_state_if_not_compromised(file_system_state: FileSystemState) -> None
set_software_state_if_not_compromised(software_state: SoftwareState) -> None
start_file_system_scan() -> None
update_booting_status() -> None
update_file_system_state() -> None
update_os_patching_status() -> None
update_resetting_status() -> None
}
class "AgentSessionABC" as primaite.agents.agent_abc.AgentSessionABC {
checkpoints_path
evaluation_path
is_eval : bool
learning_path
sb3_output_verbose_level : NONE
session_path : Union[str, Path]
session_timestamp : datetime
timestamp_str
uuid
close() -> None
{abstract}evaluate() -> None
{abstract}export() -> None
{abstract}learn() -> None
load(path: Union[str, Path]) -> None
{abstract}save() -> None
}
class "DoNothingACLAgent" as primaite.agents.simple.DoNothingACLAgent {
}
class "DoNothingNodeAgent" as primaite.agents.simple.DoNothingNodeAgent {
}
class "DummyAgent" as primaite.agents.simple.DummyAgent {
}
class "HardCodedACLAgent" as primaite.agents.hardcoded_acl.HardCodedACLAgent {
get_allow_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule]
get_allow_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
get_blocked_green_iers(green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[str, IER]
get_blocking_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
get_deny_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule]
get_matching_acl_rules(source_node_id: str, dest_node_id: str, protocol: str, port: str, acl: AccessControlList, nodes: Dict[str, Union[ServiceNode, ActiveNode]], services_list: List[str]) -> Dict[int, ACLRule]
get_matching_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
}
class "HardCodedAgentSessionABC" as primaite.agents.hardcoded_abc.HardCodedAgentSessionABC {
is_eval : bool
evaluate() -> None
export() -> None
learn() -> None
load(path: Union[str, Path]) -> None
save() -> None
}
class "HardCodedNodeAgent" as primaite.agents.hardcoded_node.HardCodedNodeAgent {
}
class "IER" as primaite.pol.ier.IER {
dest_node_id : str
end_step : int
id : str
load : int
mission_criticality : int
port : str
protocol : str
running : bool
source_node_id : str
start_step : int
get_dest_node_id() -> str
get_end_step() -> int
get_id() -> str
get_is_running() -> bool
get_load() -> int
get_mission_criticality() -> int
get_port() -> str
get_protocol() -> str
get_source_node_id() -> str
get_start_step() -> int
set_is_running(_value: bool) -> None
}
class "Link" as primaite.links.link.Link {
bandwidth : int
dest_node_name : str
id : str
protocol_list : List[Protocol]
source_node_name : str
add_protocol(_protocol: str) -> None
add_protocol_load(_protocol: str, _load: int) -> None
clear_traffic() -> None
get_bandwidth() -> int
get_current_load() -> int
get_dest_node_name() -> str
get_id() -> str
get_protocol_list() -> List[Protocol]
get_source_node_name() -> str
}
class "LinkTrafficLevels" as primaite.environment.observations.LinkTrafficLevels {
current_observation : ndarray
space : MultiDiscrete
structure : list
generate_structure() -> List[str]
update() -> None
}
class "Node" as primaite.nodes.node.Node {
booting_count : int
config_values
hardware_state : BOOTING, ON, RESETTING, OFF
name : Final[str]
node_id : Final[str]
node_type : Final[NodeType]
priority
resetting_count : int
shutting_down_count : int
reset() -> None
turn_off() -> None
turn_on() -> None
update_booting_status() -> None
update_resetting_status() -> None
update_shutdown_status() -> None
}
class "NodeLinkTable" as primaite.environment.observations.NodeLinkTable {
current_observation : ndarray
space : Box
structure : list
generate_structure() -> List[str]
update() -> None
}
class "NodeStateInstructionGreen" as primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen {
end_step : int
id : str
node_id : str
node_pol_type : str
service_name : str
start_step : int
state : Union['HardwareState', 'SoftwareState', 'FileSystemState']
get_end_step() -> int
get_node_id() -> str
get_node_pol_type() -> 'NodePOLType'
get_service_name() -> str
get_start_step() -> int
get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState']
}
class "NodeStateInstructionRed" as primaite.nodes.node_state_instruction_red.NodeStateInstructionRed {
end_step : int
id : str
initiator : str
pol_type
service_name : str
source_node_id : str
source_node_service : str
source_node_service_state : str
start_step : int
state : Union['HardwareState', 'SoftwareState', 'FileSystemState']
target_node_id : str
get_end_step() -> int
get_initiator() -> 'NodePOLInitiator'
get_pol_type() -> NodePOLType
get_service_name() -> str
get_source_node_id() -> str
get_source_node_service() -> str
get_source_node_service_state() -> str
get_start_step() -> int
get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState']
get_target_node_id() -> str
}
class "NodeStatuses" as primaite.environment.observations.NodeStatuses {
current_observation : ndarray
space : MultiDiscrete
structure : list
generate_structure() -> List[str]
update() -> None
}
class "ObservationsHandler" as primaite.environment.observations.ObservationsHandler {
current_observation
registered_obs_components : List[AbstractObservationComponent]
space
deregister(obs_component: AbstractObservationComponent) -> None
describe_structure() -> List[str]
from_config(env: 'Primaite', obs_space_config: dict) -> 'ObservationsHandler'
register(obs_component: AbstractObservationComponent) -> None
update_obs() -> None
update_space() -> None
}
class "PassiveNode" as primaite.nodes.passive_node.PassiveNode {
ip_address
}
class "Primaite" as primaite.environment.primaite_env.Primaite {
ACTION_SPACE_ACL_ACTION_VALUES : int
ACTION_SPACE_ACL_PERMISSION_VALUES : int
ACTION_SPACE_NODE_ACTION_VALUES : int
ACTION_SPACE_NODE_PROPERTY_VALUES : int
acl
action_dict : dict, Dict[int, List[int]]
action_space : Discrete, Space
action_type : int
actual_episode_count
agent_identifier
average_reward : float
env_obs : ndarray, tuple
episode_av_reward_writer
episode_count : int
episode_steps : int
green_iers : Dict[str, IER]
green_iers_reference : Dict[str, IER]
lay_down_config
links : Dict[str, Link]
links_post_blue : dict
links_post_pol : dict
links_post_red : dict
links_reference : Dict[str, Link]
max_number_acl_rules : int
network : Graph
network_reference : Graph
node_pol : Dict[str, NodeStateInstructionGreen]
nodes : Dict[str, NodeUnion]
nodes_post_blue : dict
nodes_post_pol : dict
nodes_post_red : dict
nodes_reference : Dict[str, NodeUnion]
num_links : int
num_nodes : int
num_ports : int
num_services : int
obs_config : dict
obs_handler
observation_space : Tuple, Box, Space
observation_type
ports_list : List[str]
red_iers : Dict[str, IER], dict
red_node_pol : dict, Dict[str, NodeStateInstructionRed]
services_list : List[str]
session_path : Final[Path]
step_count : int
step_info : Dict[Any]
timestamp_str : Final[str]
total_reward : float
total_step_count : int
training_config
transaction_writer
apply_actions_to_acl(_action: int) -> None
apply_actions_to_nodes(_action: int) -> None
apply_time_based_updates() -> None
close() -> None
create_acl_action_dict() -> Dict[int, List[int]]
create_acl_rule(item: Dict) -> None
create_green_ier(item: Dict) -> None
create_green_pol(item: Dict) -> None
create_link(item: Dict) -> None
create_node(item: Dict) -> None
create_node_action_dict() -> Dict[int, List[int]]
create_node_and_acl_action_dict() -> Dict[int, List[int]]
create_ports_list(ports: Dict) -> None
create_red_ier(item: Dict) -> None
create_red_pol(item: Dict) -> None
create_services_list(services: Dict) -> None
get_action_info(action_info: Dict) -> None
get_observation_info(observation_info: Dict) -> None
init_acl() -> None
init_observations() -> Tuple[spaces.Space, np.ndarray]
interpret_action_and_apply(_action: int) -> None
load_lay_down_config() -> None
output_link_status() -> None
reset() -> np.ndarray
reset_environment() -> None
reset_node(item: Dict) -> None
save_obs_config(obs_config: dict) -> None
set_as_eval() -> None
step(action: int) -> Tuple[np.ndarray, float, bool, Dict]
update_environent_obs() -> None
}
class "PrimaiteSession" as primaite.primaite_session.PrimaiteSession {
evaluation_path : Optional[Path], Path
is_load_session : bool
learning_path : Optional[Path], Path
session_path : Optional[Path], Path
timestamp_str : str, Optional[str]
close() -> None
evaluate() -> None
learn() -> None
setup() -> None
}
class "Protocol" as primaite.common.protocol.Protocol {
load : int
name : str
add_load(_load: int) -> None
clear_load() -> None
get_load() -> int
get_name() -> str
}
class "RLlibAgent" as primaite.agents.rllib.RLlibAgent {
{abstract}evaluate() -> None
{abstract}export() -> None
learn() -> None
{abstract}load(path: Union[str, Path]) -> RLlibAgent
save(overwrite_existing: bool) -> None
}
class "RandomAgent" as primaite.agents.simple.RandomAgent {
}
class "SB3Agent" as primaite.agents.sb3.SB3Agent {
is_eval : bool
evaluate() -> None
{abstract}export() -> None
learn() -> None
save() -> None
}
class "Service" as primaite.common.service.Service {
name : str
patching_count : int
port : str
software_state : GOOD
reduce_patching_count() -> None
}
class "ServiceNode" as primaite.nodes.service_node.ServiceNode {
services : Dict[str, Service]
add_service(service: Service) -> None
get_service_state(protocol_name: str) -> SoftwareState
has_service(protocol_name: str) -> bool
service_is_overwhelmed(protocol_name: str) -> bool
service_running(protocol_name: str) -> bool
set_service_state(protocol_name: str, software_state: SoftwareState) -> None
set_service_state_if_not_compromised(protocol_name: str, software_state: SoftwareState) -> None
update_booting_status() -> None
update_resetting_status() -> None
update_services_patching_status() -> None
}
class "SessionOutputWriter" as primaite.utils.session_output_writer.SessionOutputWriter {
learning_session : bool
transaction_writer : bool
close() -> None
write(data: Union[Tuple, Transaction]) -> None
}
class "TrainingConfig" as primaite.config.training_config.TrainingConfig {
action_type
agent_framework
agent_identifier
agent_load_file : Optional[str]
all_ok : float
checkpoint_every_n_episodes : int
compromised : float
compromised_should_be_good : float
compromised_should_be_overwhelmed : float
compromised_should_be_patching : float
corrupt : float
corrupt_should_be_destroyed : float
corrupt_should_be_good : float
corrupt_should_be_repairing : float
corrupt_should_be_restoring : float
deep_learning_framework
destroyed : float
destroyed_should_be_corrupt : float
destroyed_should_be_good : float
destroyed_should_be_repairing : float
destroyed_should_be_restoring : float
deterministic : bool
file_system_repairing_limit : int
file_system_restoring_limit : int
file_system_scanning_limit : int
good_should_be_compromised : float
good_should_be_corrupt : float
good_should_be_destroyed : float
good_should_be_overwhelmed : float
good_should_be_patching : float
good_should_be_repairing : float
good_should_be_restoring : float
green_ier_blocked : float
hard_coded_agent_view
implicit_acl_rule
load_agent : bool
max_number_acl_rules : int
node_booting_duration : int
node_reset_duration : int
node_shutdown_duration : int
num_eval_episodes : int
num_eval_steps : int
num_train_episodes : int
num_train_steps : int
observation_space : dict
observation_space_high_value : int
off_should_be_on : float
off_should_be_resetting : float
on_should_be_off : float
on_should_be_resetting : float
os_patching_duration : int
overwhelmed : float
overwhelmed_should_be_compromised : float
overwhelmed_should_be_good : float
overwhelmed_should_be_patching : float
patching : float
patching_should_be_compromised : float
patching_should_be_good : float
patching_should_be_overwhelmed : float
random_red_agent : bool
red_ier_running : float
repairing : float
repairing_should_be_corrupt : float
repairing_should_be_destroyed : float
repairing_should_be_good : float
repairing_should_be_restoring : float
resetting : float
resetting_should_be_off : float
resetting_should_be_on : float
restoring : float
restoring_should_be_corrupt : float
restoring_should_be_destroyed : float
restoring_should_be_good : float
restoring_should_be_repairing : float
sb3_output_verbose_level
scanning : float
seed : Optional[int]
service_patching_duration : int
session_type
time_delay : int
from_dict(config_dict: Dict[str, Any]) -> TrainingConfig
to_dict(json_serializable: bool) -> Dict
}
class "Transaction" as primaite.transactions.transaction.Transaction {
action_space : Optional[int]
agent_identifier
episode_number : int
obs_space : str
obs_space_description : NoneType, Optional[List[str]], list
obs_space_post : Optional[Union['np.ndarray', Tuple['np.ndarray']]]
obs_space_pre : Optional[Union['np.ndarray', Tuple['np.ndarray']]]
reward : Optional[float], float
step_number : int
timestamp : datetime
as_csv_data() -> Tuple[List, List]
}
primaite.agents.hardcoded_abc.HardCodedAgentSessionABC --|> primaite.agents.agent_abc.AgentSessionABC
primaite.agents.hardcoded_acl.HardCodedACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.hardcoded_node.HardCodedNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.rllib.RLlibAgent --|> primaite.agents.agent_abc.AgentSessionABC
primaite.agents.sb3.SB3Agent --|> primaite.agents.agent_abc.AgentSessionABC
primaite.agents.simple.DoNothingACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.simple.DoNothingNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.simple.DummyAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.agents.simple.RandomAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
primaite.environment.observations.AccessControlList_ --|> primaite.environment.observations.AbstractObservationComponent
primaite.environment.observations.LinkTrafficLevels --|> primaite.environment.observations.AbstractObservationComponent
primaite.environment.observations.NodeLinkTable --|> primaite.environment.observations.AbstractObservationComponent
primaite.environment.observations.NodeStatuses --|> primaite.environment.observations.AbstractObservationComponent
primaite.nodes.active_node.ActiveNode --|> primaite.nodes.node.Node
primaite.nodes.passive_node.PassiveNode --|> primaite.nodes.node.Node
primaite.nodes.service_node.ServiceNode --|> primaite.nodes.active_node.ActiveNode
primaite.common.service.Service --|> primaite.nodes.service_node.ServiceNode
primaite.acl.access_control_list.AccessControlList --* primaite.environment.primaite_env.Primaite : acl
primaite.acl.acl_rule.ACLRule --* primaite.acl.access_control_list.AccessControlList : acl_implicit_rule
primaite.agents.hardcoded_acl.HardCodedACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.hardcoded_node.HardCodedNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.rllib.RLlibAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.sb3.SB3Agent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.DoNothingACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.DoNothingNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.DummyAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.agents.simple.RandomAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
primaite.config.training_config.TrainingConfig --* primaite.agents.agent_abc.AgentSessionABC : _training_config
primaite.config.training_config.TrainingConfig --* primaite.environment.primaite_env.Primaite : training_config
primaite.environment.observations.ObservationsHandler --* primaite.environment.primaite_env.Primaite : obs_handler
primaite.environment.primaite_env.Primaite --* primaite.agents.agent_abc.AgentSessionABC : _env
primaite.environment.primaite_env.Primaite --* primaite.agents.hardcoded_abc.HardCodedAgentSessionABC : _env
primaite.environment.primaite_env.Primaite --* primaite.agents.sb3.SB3Agent : _env
primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : episode_av_reward_writer
primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : transaction_writer
primaite.config.training_config.TrainingConfig --o primaite.nodes.node.Node : config_values
primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen --* primaite.environment.primaite_env.Primaite
primaite.nodes.node_state_instruction_red.NodeStateInstructionRed --* primaite.environment.primaite_env.Primaite
primaite.pol.ier.IER --* primaite.environment.primaite_env.Primaite
primaite.common.protocol.Protocol --o primaite.links.link.Link
primaite.links.link.Link --* primaite.environment.primaite_env.Primaite
primaite.config.training_config.TrainingConfig --o primaite.nodes.active_node.ActiveNode
primaite.utils.session_output_writer.SessionOutputWriter --> primaite.transactions.transaction.Transaction
primaite.transactions.transaction.Transaction --> primaite.environment.primaite_env.Primaite
@enduml

34
docs/Makefile Normal file

@@ -0,0 +1,34 @@
# Minimal makefile for Sphinx documentation
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
AUTOSUMMARY="source\_autosummary"
# Remove command is different depending on OS
ifdef OS
RM = IF exist $(AUTOSUMMARY) ( RMDIR $(AUTOSUMMARY) /s /q )
else
ifeq ($(shell uname), Linux)
RM = rm -rf $(AUTOSUMMARY)
endif
endif
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
clean:
	$(RM)
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile | clean
	pip-licenses --format=rst --with-urls --output-file=source/primaite-dependencies.rst
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

0
docs/_static/.gitkeep vendored Normal file

@@ -0,0 +1,41 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
:members:
:show-inheritance:
:inherited-members:
:special-members: __init__, __call__, __add__, __mul__
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
:nosignatures:
{% for item in methods %}
{%- if not item.startswith('_') %}
~{{ name }}.{{ item }}
{%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}
{% block attributes %}
{% if attributes %}
.. rubric:: {{ _('Attributes') }}
.. autosummary::
{% for item in attributes %}
~{{ name }}.{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}


@@ -0,0 +1,73 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
..
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..
{{ fullname | escape | underline}}
.. automodule:: {{ fullname }}
{% block attributes %}
{% if attributes %}
.. rubric:: Module attributes
.. autosummary::
:toctree:
{% for item in attributes %}
{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block functions %}
{% if functions %}
.. rubric:: {{ _('Functions') }}
.. autosummary::
:toctree:
:nosignatures:
{% for item in functions %}
{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block classes %}
{% if classes %}
.. rubric:: {{ _('Classes') }}
.. autosummary::
:toctree:
:template: custom-class-template.rst
:nosignatures:
{% for item in classes %}
{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block exceptions %}
{% if exceptions %}
.. rubric:: {{ _('Exceptions') }}
.. autosummary::
:toctree:
{% for item in exceptions %}
{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block modules %}
{% if modules %}
.. autosummary::
:toctree:
:template: custom-module-template.rst
:recursive:
{% for item in modules %}
{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}

20
docs/api.rst Normal file

@@ -0,0 +1,20 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
..
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden
(not declared in any toctree) to remove an unnecessary intermediate page; index.rst instead points directly to the
package page. DO NOT REMOVE THIS FILE!
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
..
.. autosummary::
   :toctree: source/_autosummary
   :template: custom-module-template.rst
   :recursive:

   primaite
   tests


@@ -0,0 +1,67 @@
#!/bin/bash
set -x
apt-get update
apt-get -y install git rsync python3-sphinx
pwd
ls -lah
export SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct)
##############
# BUILD DOCS #
##############
cd docs
# Python Sphinx, configured with source/conf.py
# See https://www.sphinx-doc.org/
make clean
make html
cd ..
#######################
# Update GitHub Pages #
#######################
git config --global user.name "${GITHUB_ACTOR}"
git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
docroot=`mktemp -d`
rsync -av "$PWD/docs/_build/html/" "${docroot}/"
pushd "${docroot}"
git init
git remote add deploy "https://token:${GITHUB_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
git checkout -b sphinx-docs-github-pages
# Adds .nojekyll file to the root to signal to GitHub that
# directories that start with an underscore (_) can remain
touch .nojekyll
# Add README
cat > README.md <<EOF
# README for the Sphinx Docs GitHub Pages Branch
This branch is simply a cache for the website served from https://Autonomous-Resilient-Cyber-Defence.github.io/PrimAITE/,
and is not intended to be viewed on github.com.
For more information on how this site is built using Sphinx, Read the Docs, GitHub Actions/Pages, and demo
implementation from https://github.com/annegentle, see:
* https://www.docslikecode.com/articles/github-pages-python-sphinx/
* https://tech.michaelaltfield.net/2020/07/18/sphinx-rtd-github-pages-1
* https://github.com/annegentle/create-demo
EOF
# Copy the resulting html pages built from Sphinx to the sphinx-docs-github-pages branch
git add .
# Make a commit with changes and any new files
msg="Updating Docs for commit ${GITHUB_SHA} made on `date -d"@${SOURCE_DATE_EPOCH}" --iso-8601=seconds` from ${GITHUB_REF} by ${GITHUB_ACTOR}"
git commit -am "${msg}"
# overwrite the contents of the sphinx-docs-github-pages branch on our github.com repo
git push deploy sphinx-docs-github-pages --force
popd # return to main repo sandbox root
# exit cleanly
exit 0

57
docs/conf.py Normal file

@@ -0,0 +1,57 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import datetime
import os
import sys
import furo  # noqa
sys.path.insert(0, os.path.abspath("../"))
# -- Project information -----------------------------------------------------
year = datetime.datetime.now().year
project = "PrimAITE"
copyright = f"Copyright (C) Defence Science and Technology Laboratory UK 2021 - {year}"
author = "Defence Science and Technology Laboratory UK"
# The short Major.Minor.Build version
with open("../src/primaite/VERSION", "r") as file:
    version = file.readline().strip()  # drop the trailing newline, if any
# The full version, including alpha/beta/rc tags
release = version
html_title = f"{project} v{release} docs"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.autodoc", # Core Sphinx library for auto html doc generation from docstrings
"sphinx.ext.autosummary", # Create summary tables for modules/classes/methods etc
"sphinx.ext.intersphinx", # Link to other project's documentation (see mapping below)
"sphinx.ext.viewcode", # Add a link to the Python source code for classes, functions etc.
"sphinx.ext.todo",
"sphinx_copybutton", # Adds a copy button to code blocks
]
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "furo"
html_static_path = ["_static"]

119
docs/index.rst Normal file

@@ -0,0 +1,119 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Welcome to PrimAITE's documentation
====================================
What is PrimAITE?
-----------------
Overview
^^^^^^^^
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for the purposes of training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment, which includes:
- The ability to model a relevant platform / system context;
- Modelling an adversarial agent that the defensive agent can be trained and evaluated against;
- The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, operating systems, services and traffic loading on links;
- Modelling background pattern-of-life;
- Operating at machine speed to enable fast training cycles.
Features
^^^^^^^^
PrimAITE incorporates the following features:
- Highly configurable (via YAML files) to provide the means to model a variety of platform / system laydowns, mission profiles and adversarial attack scenarios;
- A Reinforcement Learning (RL) reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure mission success;
- Provision of logging to support AI performance / effectiveness assessment;
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life, adversarial behaviour and mission data (on a sliding scale of criticality);
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
- Application of traffic to the links of the platform / system laydown adheres to the ACL ruleset;
- Presents both an OpenAI Gym and a Ray RLlib interface to the environment, allowing integration with any compliant defensive agents;
- Allows for the saving and loading of trained defensive agents;
- Stochastic adversarial agent behaviour;
- Full capture of discrete logs relating to agent training or evaluation (system state, agent actions taken, instantaneous and average reward for every step of every episode);
- Distinct control over running a training and / or evaluation session;
- NetworkX provides laydown visualisation capability.
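The ACL behaviour listed above (DENY/ALLOW, source IP address, destination IP address, protocol and port) can be illustrated with a minimal sketch. The `Rule` dataclass, the `"ANY"` wildcard, the first-match evaluation order and the `implicit_permission` default are assumptions made for illustration only; they are not PrimAITE's actual `ACLRule` or `AccessControlList` API.

```python
from dataclasses import dataclass
from typing import List

ANY = "ANY"  # hypothetical wildcard value, assumed for this sketch


@dataclass(frozen=True)
class Rule:
    """Hypothetical ACL rule: a permission plus the four match fields."""

    permission: str  # "ALLOW" or "DENY"
    source_ip: str
    dest_ip: str
    protocol: str
    port: str

    def matches(self, source_ip: str, dest_ip: str, protocol: str, port: str) -> bool:
        """Return True if every field equals the traffic value or is a wildcard."""
        pairs = [
            (self.source_ip, source_ip),
            (self.dest_ip, dest_ip),
            (self.protocol, protocol),
            (self.port, port),
        ]
        return all(rule_value in (ANY, value) for rule_value, value in pairs)


def is_allowed(
    rules: List[Rule],
    source_ip: str,
    dest_ip: str,
    protocol: str,
    port: str,
    implicit_permission: str = "DENY",
) -> bool:
    """Apply the first matching rule; otherwise fall back to the implicit permission."""
    for rule in rules:
        if rule.matches(source_ip, dest_ip, protocol, port):
            return rule.permission == "ALLOW"
    return implicit_permission == "ALLOW"


# Example rule set: block one host from web traffic, allow everyone else to the server.
rules = [
    Rule("DENY", "192.168.1.5", ANY, "TCP", "80"),
    Rule("ALLOW", ANY, "192.168.1.10", "TCP", "80"),
]
```

In this sketch, traffic that matches no rule is handled by the implicit permission, which mirrors the idea of a configurable implicit rule without claiming to reproduce how PrimAITE evaluates it.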
Architecture
^^^^^^^^^^^^
PrimAITE is a Python application and is therefore Operating System agnostic. The OpenAI Gym and Ray RLlib frameworks are employed to provide an interface and source for AI agents. Configuration of PrimAITE is achieved via included YAML files which support full control over the platform / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count. NetworkX-based nodes and links host Python classes to present attributes and methods, and hence a more representative platform / system can be modelled within the simulation.
Training & Evaluation Capability
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its OpenAI Gym and RLlib compliant interface. Scenarios can be constructed to reflect platform / system laydowns consisting of any configuration of nodes (e.g. PCs, servers, switches etc.) and network links between them. All nodes can be configured to model services (and their status) and the traffic loading between them over the network links. Traffic loading is broken down into a per-service granularity, relating directly to a protocol (e.g. Service A would be configured as a TCP service, and TCP traffic then flows between instances of Service A under the direction of a tailored IER). Highlights of PrimAITE's training and evaluation capability are:
- The scenario is not bound to a representation of any platform, system, technology or mission;
- Fully configurable (network / system laydown, IERs, node pattern-of-life, ACL, number of episodes, steps per episode) and repeatable to suit the requirements of AI agents;
- Can integrate with any OpenAI Gym or RLlib compliant AI agent.
Use of PrimAITE default scenarios within ARCD is supported by a “Use Case Profile” tailored to the scenario.
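To make the Gym-style contract concrete, the sketch below runs a random agent against a stand-in environment that follows the classic `reset()` / `step(action)` pattern, where `step` returns an observation, a reward, a done flag and an info dictionary. `ToyEnv`, its reward rule and `run_episode` are hypothetical and exist only for illustration; they are not part of PrimAITE.

```python
import random
from typing import Dict, List, Tuple


class ToyEnv:
    """Hypothetical stand-in for a Gym-style environment (not PrimAITE itself)."""

    def __init__(self, episode_steps: int = 5, num_actions: int = 3) -> None:
        self.episode_steps = episode_steps
        self.num_actions = num_actions
        self.step_count = 0

    def reset(self) -> List[int]:
        """Start a new episode and return the initial observation."""
        self.step_count = 0
        return [0]

    def step(self, action: int) -> Tuple[List[int], float, bool, Dict]:
        """Advance one step; in this toy, action 0 ("do nothing") earns no reward."""
        self.step_count += 1
        reward = 0.0 if action == 0 else 1.0
        done = self.step_count >= self.episode_steps
        return [self.step_count], reward, done, {}


def run_episode(env: ToyEnv, rng: random.Random) -> Tuple[int, float]:
    """Run one episode with a random agent; return (steps taken, total reward)."""
    env.reset()
    total_reward, steps, done = 0.0, 0, False
    while not done:
        action = rng.randrange(env.num_actions)  # pick a random valid action
        _, reward, done, _ = env.step(action)
        total_reward += reward
        steps += 1
    return steps, total_reward


steps, total_reward = run_episode(ToyEnv(), random.Random(0))
```

Any agent that drives an environment through this loop could, in principle, be pointed at a compliant environment instead of the toy one; the details of constructing a real PrimAITE session are not shown here.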
AI Assessment Capability
^^^^^^^^^^^^^^^^^^^^^^^^
PrimAITE includes the capability to support in-depth assessment of cyber defence AI by outputting logs of the environment state and AI behaviour throughout both training and evaluation sessions. These logs include the following data:
- Timestamp;
- Episode and step number;
- Agent identifier;
- Observation space;
- Action taken (by defensive AI);
- Reward value.
Logs are available in CSV format and provide coverage of the above data for every step of every episode.
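As an example of how such per-step logs might be assessed, the sketch below computes the average reward per episode from CSV text using only the Python standard library. The column names (`timestamp`, `episode`, `step`, `reward`) and the sample rows are assumptions made for illustration; the real log layout may differ.

```python
import csv
import io
from collections import defaultdict
from typing import Dict

# Hypothetical log excerpt; real column names and layout may differ.
SAMPLE_LOG = """timestamp,episode,step,reward
2023-07-28T12:00:00,1,1,0.5
2023-07-28T12:00:01,1,2,1.5
2023-07-28T12:00:02,2,1,-1.0
2023-07-28T12:00:03,2,2,0.0
"""


def average_reward_per_episode(csv_text: str) -> Dict[int, float]:
    """Group rows by episode number and average the reward column."""
    totals: Dict[int, float] = defaultdict(float)
    counts: Dict[int, int] = defaultdict(int)
    for row in csv.DictReader(io.StringIO(csv_text)):
        episode = int(row["episode"])
        totals[episode] += float(row["reward"])
        counts[episode] += 1
    return {episode: totals[episode] / counts[episode] for episode in totals}


averages = average_reward_per_episode(SAMPLE_LOG)
```

For a log file on disk, the same function could be fed the file's contents after reading it as text, once the column names have been adjusted to match the actual header row.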
What is PrimAITE built with
---------------------------
* `OpenAI's Gym <https://gym.openai.com/>`_ is used as the basis for AI blue agent interaction with the PrimAITE environment
* `NetworkX <https://github.com/networkx/networkx>`_ is used as the underlying data structure for the PrimAITE environment
* `Stable Baselines 3 <https://github.com/DLR-RM/stable-baselines3>`_ is used as a default source of RL algorithms (although PrimAITE is not limited to SB3 agents)
* `Ray RLlib <https://github.com/ray-project/ray>`_ is used as an additional source of RL algorithms
* `Typer <https://github.com/tiangolo/typer>`_ is used for building CLIs (Command Line Interface applications)
* `JupyterLab <https://github.com/jupyterlab/jupyterlab>`_ is used as an extensible environment for interactive and reproducible computing, based on the Jupyter Notebook architecture
* `Platformdirs <https://github.com/platformdirs/platformdirs>`_ is used for finding the right location to store user data and configuration, which varies per platform
* `Plotly <https://github.com/plotly/plotly.py>`_ is used for building high level charts
Getting Started with PrimAITE
-----------------------------
Head over to the :ref:`getting-started` page to install and set up PrimAITE!
.. toctree::
:maxdepth: 8
:caption: Contents:
:hidden:
source/getting_started
source/about
source/config
source/primaite_session
source/custom_agent
PrimAITE API <source/_autosummary/primaite>
PrimAITE Tests <source/_autosummary/tests>
source/dependencies
source/glossary
source/migration_1.2_-_2.0
.. TODO: Add project links once public repo has been created
.. toctree::
:caption: Project Links:
:hidden:
Code <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE>
Issues <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues>
Pull Requests <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/pulls>
Discussions <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/discussions>

58
docs/make.bat Normal file

@@ -0,0 +1,58 @@
@ECHO OFF
setlocal EnableDelayedExpansion
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set AUTOSUMMARYDIR="%cd%\source\_autosummary\"
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
REM delete autosummary if it exists
IF EXIST %AUTOSUMMARYDIR% (
echo deleting %AUTOSUMMARYDIR%
RMDIR %AUTOSUMMARYDIR% /s /q
)
REM print the YT licenses
set LICENSEBUILD=pip-licenses --format=rst --with-urls
set DEPS="%cd%\source\primaite-dependencies.rst"
%LICENSEBUILD% --output-file=%DEPS%
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:clean
IF EXIST %AUTOSUMMARYDIR% (
echo deleting %AUTOSUMMARYDIR%
RMDIR %AUTOSUMMARYDIR% /s /q
)
:end
popd

414
docs/source/about.rst Normal file

@@ -0,0 +1,414 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _about:
About PrimAITE
==============
Features
********
PrimAITE provides the following features:
* A flexible network / system laydown based on the Python networkx framework
* Nodes and links (edges) host Python classes in order to present attributes and methods (and hence, a more representative model of a platform / system)
* A 'green agent' Information Exchange Requirement (IER) function allows the representation of traffic (protocols and loading) on any / all links. Application of IERs is based on the status of node operating systems and services
* A 'green agent' node Pattern-of-Life (PoL) function allows the representation of core behaviours on nodes (e.g. changing the Hardware state, Software State, Service state, or File System state)
* An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP, destination IP, protocol and port). Application of IERs adheres to any ACL restrictions
* Presents an OpenAI Gym interface to the environment, allowing integration with any OpenAI Gym compliant defensive agents
* Red agent activity based on 'red' IERs and 'red' PoL
* Defined reward function for use with RL agents (based on nodes status, and green / red IER success)
* Fully configurable (network / system laydown, IERs, node PoL, ACL, episode step period, episode max steps) and repeatable to suit the training requirements of agents. Therefore, not bound to a representation of any particular platform, system or technology
* Full capture of discrete metrics relating to agent training (full system state, agent actions taken, average reward)
* Networkx provides laydown visualisation capability
Architecture - Nodes and Links
******************************
**Nodes**
An inheritance model has been adopted in order to model nodes. All nodes have the following base attributes (Class: Node):
* ID
* Name
* Type (e.g. computer, switch, RTU - enumeration)
* Priority (P1, P2, P3, P4 or P5 - enumeration)
* Hardware State (ON, OFF, RESETTING, SHUTTING_DOWN, BOOTING - enumeration)
Active Nodes also have the following attributes (Class: Active Node):
* IP Address
* Software State (GOOD, PATCHING, COMPROMISED - enumeration)
* File System State (GOOD, CORRUPT, DESTROYED, REPAIRING, RESTORING - enumeration)
Service Nodes also have the following attributes (Class: Service Node):
* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type)
* Service state (GOOD, PATCHING, COMPROMISED, OVERWHELMED - enumeration)
Passive Nodes are currently not used (but may be employed for non IP-based components such as machinery actuators in future releases).
**Links**
Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality. Links include the following attributes:
* ID
* Name
* Bandwidth (bits/s)
* Source node ID
* Destination node ID
* Protocol list (containing the loading of protocols currently running on the link)
When the simulation runs, IERs are applied to the links in order to model traffic loading, individually assigned to each protocol. This allows green (background) and red agent behaviour to be modelled, and defensive agents to identify suspicious traffic patterns at a protocol / traffic loading level of fidelity.
Information Exchange Requirements (IERs)
****************************************
PrimAITE adopts the concept of Information Exchange Requirements (IERs) to model both green agent (background) and red agent (adversary) behaviour. IERs are used to initiate modelling of traffic loading on the network, and have the following attributes:
* ID
* Start step (i.e. which step in the training episode should the IER start)
* End step (i.e. which step in the training episode should the IER end)
* Source node ID
* Destination node ID
* Load (bits/s)
* Protocol
* Port
* Running status (i.e. on / off)
The application of green agent IERs between a source and destination follows a number of rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational, and is the service (protocol / port) associated with the IER (a) present on that node and (b) already in a compromised state
3. Is the destination node operational, and is the service (protocol / port) associated with the IER present on that node
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
Assuming the rules pass, the IER is applied to all relevant links (based on use of OSPF) between source and destination.
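The two rule sets above can be sketched as a pair of predicates. This is only an illustration: the ``Node`` class, the function names and the reading of "operational" (powered on, O/S not patching) are assumptions made for this example, not PrimAITE's own API, and the start/end window is assumed to be inclusive.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable


@dataclass
class Node:
    """Hypothetical stand-in for a node; not PrimAITE's own class."""

    hardware_state: str = "ON"
    software_state: str = "GOOD"
    services: Dict[str, str] = field(default_factory=dict)  # protocol -> state

    def operational(self) -> bool:
        # Assumption: "operational" means powered on and O/S not patching.
        return self.hardware_state == "ON" and self.software_state != "PATCHING"


def _common_checks(step, start, end, acl_blocks, switches: Iterable[Node]) -> bool:
    # Rules 1, 4 and 5 are shared by green and red IERs.
    in_window = start <= step <= end
    return in_window and not acl_blocks and all(s.operational() for s in switches)


def green_ier_can_run(step, start, end, protocol, src, dst, switches, acl_blocks):
    # Rules 2 and 3: both ends operational, service present and not PATCHING.
    ends_ok = all(
        n.operational() and n.services.get(protocol) not in (None, "PATCHING")
        for n in (src, dst)
    )
    return ends_ok and _common_checks(step, start, end, acl_blocks, switches)


def red_ier_can_run(step, start, end, protocol, src, dst, switches, acl_blocks):
    # Rule 2: the source service must already be COMPROMISED.
    src_ok = src.hardware_state == "ON" and src.services.get(protocol) == "COMPROMISED"
    # Rule 3: the destination only needs to be on and to host the service.
    dst_ok = dst.hardware_state == "ON" and protocol in dst.services
    return src_ok and dst_ok and _common_checks(step, start, end, acl_blocks, switches)
```

The only difference between the two predicates is in rules 2 and 3: a red IER requires an already compromised source service, while a green IER requires healthy services at both ends.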
Node Pattern-of-Life
********************
Every node can be impacted (i.e. have a status change applied to it) by either green agent pattern-of-life or red agent pattern-of-life. This is distinct from IERs, and allows for attacks (and defence) to be modelled purely within the confines of a node.
The status changes that can be made to a node are as follows:
* All Nodes:
* Hardware State:
* ON
* OFF
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
* BOOTING
* SHUTTING_DOWN
* Active Nodes and Service Nodes:
* Software State:
* GOOD
* PATCHING - when a status of patching is entered, the node will automatically exit this state after a number of steps (as defined by the osPatchingDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* File System State:
* GOOD
* CORRUPT (can be resolved by repair or restore)
* DESTROYED (can be resolved by restore only)
* REPAIRING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRepairingLimit configuration item) after which it returns to a GOOD state
* RESTORING - when a status of restoring is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRestoringLimit configuration item) after which it returns to a GOOD state
* Service Nodes only:
* Service State (for any associated service):
* GOOD
* PATCHING - when a status of patching is entered, the service will automatically exit this state after a number of steps (as defined by the servicePatchingDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* OVERWHELMED
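The timed states in the list above (RESETTING, PATCHING, REPAIRING, RESTORING) all follow the same countdown pattern. A minimal sketch of that pattern is shown here; the ``TimedState`` class is hypothetical and written for this example only, and the durations would come from the configuration items named above.

```python
class TimedState:
    """Hypothetical helper: a state that reverts after a fixed number of steps."""

    def __init__(self, resting_state: str) -> None:
        self.resting_state = resting_state  # e.g. "GOOD" or "ON"
        self.state = resting_state
        self.steps_left = 0

    def enter(self, transient_state: str, duration: int) -> None:
        # Move into a temporary state such as PATCHING for `duration` steps.
        self.state = transient_state
        self.steps_left = duration

    def step(self) -> None:
        # Count down once per simulation step; revert when the count hits zero.
        if self.steps_left > 0:
            self.steps_left -= 1
            if self.steps_left == 0:
                self.state = self.resting_state


os_state = TimedState("GOOD")
os_state.enter("PATCHING", duration=2)
os_state.step()  # still PATCHING
os_state.step()  # back to GOOD
```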
Red agent pattern-of-life has an additional feature not found in the green pattern-of-life. This is the ability to influence the state of the attributes of a node via a number of different conditions:
* DIRECT:
The pattern-of-life described by the configuration file item will be applied regardless of any other conditions in the network. This is particularly useful for direct red agent entry into the network.
* IER:
The pattern-of-life described by the configuration file item will be applied to the service on the node, only if there is an IER of the same protocol / service type incoming at the specified timestep.
* SERVICE:
The pattern-of-life described by the configuration file item will be applied to the node based on the state of a service. The service can either be on the same node, or a different node within the network.
Access Control List modelling
*****************************
An Access Control List (ACL) is modelled to provide the means to manage traffic flows in the system. This will allow defensive agents the means to turn on / off rules, or potentially create new rules, to counter an attack.
The ACL follows a standard network firewall format. For example:
.. list-table:: ACL example
:widths: 25 25 25 25 25
:header-rows: 1
* - Permission
- Source IP
- Dest IP
- Protocol
- Port
* - DENY
- 192.168.1.2
- 192.168.1.3
- HTTPS
- 443
* - ALLOW
- 192.168.1.4
- ANY
- SMTP
- 25
* - DENY
- ANY
- 192.168.1.5
- ANY
- ANY
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry.
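One way to read the ordering rule above is "the last matching rule wins". The sketch below applies that reading to the example table; ``AclRule``, ``is_blocked`` and the ``implicit`` default are names invented for this illustration and should not be taken as PrimAITE's implementation.

```python
from typing import List, NamedTuple


class AclRule(NamedTuple):
    permission: str  # "ALLOW" or "DENY"
    source_ip: str
    dest_ip: str
    protocol: str
    port: str


def _matches(rule_value: str, actual: str) -> bool:
    # "ANY" acts as a wildcard for that field.
    return rule_value == "ANY" or rule_value == actual


def is_blocked(rules: List[AclRule], source_ip, dest_ip, protocol, port, implicit="ALLOW"):
    decision = implicit  # assumed fallback when no rule matches
    for rule in rules:
        fields = zip(rule[1:], (source_ip, dest_ip, protocol, port))
        if all(_matches(wanted, actual) for wanted, actual in fields):
            decision = rule.permission  # a later match overrides an earlier one
    return decision == "DENY"


rules = [
    AclRule("DENY", "192.168.1.2", "192.168.1.3", "HTTPS", "443"),
    AclRule("ALLOW", "192.168.1.4", "ANY", "SMTP", "25"),
    AclRule("DENY", "ANY", "192.168.1.5", "ANY", "ANY"),
]
```

With these rules, SMTP traffic from 192.168.1.4 is allowed to most destinations, but the later DENY entry for 192.168.1.5 overrides the earlier ALLOW for that destination.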
Observation Spaces
******************
The observation space provides the blue agent with information about the current status of nodes and links.
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to include, and how it should be formatted.
NodeLinkTable component
-----------------------
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box``. An example observation space is provided below:
.. list-table:: Observation Space example
:widths: 25 25 25 25 25 25 25
:header-rows: 1
* -
- ID
- Hardware State
- Software State
- File System State
- Service / Protocol A
- Service / Protocol B
* - Node A
- 1
- 1
- 1
- 1
- 1
- 1
* - Node B
- 2
- 1
- 3
- 1
- 1
- 1
* - Node C
- 3
- 2
- 1
- 1
- 3
- 2
* - Link 1
- 5
- 0
- 0
- 0
- 0
- 10000
* - Link 2
- 6
- 0
- 0
- 0
- 0
- 10000
* - Link 3
- 7
- 0
- 0
- 0
- 5000
- 0
For the nodes, the following values are represented:
.. code-block::
[
ID
Hardware State (1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
Operating System State (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
File System State (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
Service1/Protocol1 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
Service2/Protocol2 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
(Note that each service available in the network is provided as a column, although not all nodes may utilise all services)
For the links, the following statuses are represented:
.. code-block::
[
ID
Hardware State (0=not applicable)
Operating System State (0=not applicable)
File System State (0=not applicable)
Service1/Protocol1 state (Traffic load from this protocol on this link)
Service2/Protocol2 state (Traffic load from this protocol on this link)
]
NodeStatus component
----------------------
This is a MultiDiscrete observation space that can be thought of as a one-dimensional vector of discrete states.
The example above would have the following structure:
.. code-block::
[
node1_info
node2_info
node3_info
]
Each ``node_info`` contains the following:
.. code-block::
[
hardware_state (0=none, 1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
software_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
file_system_state (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
service1_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
.. code-block::
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4])
.. note::
NodeStatus observation component provides information only about nodes. Links are not considered.
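The element count quoted above follows from three fixed fields per node plus one field per service. The helper names below are hypothetical and only illustrate the arithmetic and how a flat vector maps back onto per-node blocks.

```python
from typing import List, Sequence

FIXED_FIELDS = 3  # hardware_state, software_state, file_system_state


def node_status_length(num_nodes: int, num_services: int) -> int:
    # Each node contributes its fixed fields plus one state per service.
    return num_nodes * (FIXED_FIELDS + num_services)


def split_node_info(observation: Sequence[int], num_services: int) -> List[List[int]]:
    # Cut the flat vector into one ``node_info`` block per node.
    width = FIXED_FIELDS + num_services
    return [list(observation[i:i + width]) for i in range(0, len(observation), width)]


print(node_status_length(num_nodes=3, num_services=2))  # 15
```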
LinkTrafficLevels
-----------------
This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value.
There are two configurable parameters:
* ``quantisation_levels`` determines how many discrete bins to use for converting the continuous traffic value to discrete (default is 5).
* ``combine_service_traffic`` determines whether to separately output traffic use for each network protocol or whether to combine them into an overall value for the link. (default is ``True``)
For example, with default parameters and a network with three links, the structure of this component would be:
.. code-block::
[
link1_status
link2_status
link3_status
]
Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth.
.. code-block::
0 = No traffic (0%)
1 = low traffic (1%-33%)
2 = medium traffic (33%-66%)
3 = high traffic (66%-99%)
4 = max traffic/ overwhelmed (100%)
Using ``gym`` notation, the shape of the obs space is: ``gym.spaces.MultiDiscrete([5,5,5])``.
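A possible way to turn a continuous load into one of these levels is sketched below. The function name and the exact binning rule are assumptions for illustration: level 0 is reserved for no traffic, the top level for a saturated link, and the remaining levels split the range in between evenly.

```python
def quantise_traffic(load: int, bandwidth: int, quantisation_levels: int = 5) -> int:
    """Map a link load (bits/s) onto a discrete level, as a rough sketch."""
    if quantisation_levels < 3:
        raise ValueError("quantisation_levels must be 3 or greater")
    if load <= 0:
        return 0  # no traffic
    if load >= bandwidth:
        return quantisation_levels - 1  # link is at or over capacity
    # Spread partial loads evenly over the intermediate levels.
    return 1 + (load * (quantisation_levels - 2)) // bandwidth


levels = [quantise_traffic(load, 10000) for load in (0, 2000, 5000, 9000, 10000)]
print(levels)  # [0, 1, 2, 3, 4]
```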
Action Spaces
**************
The action space available to the blue agent comes in three types:
1. Node-based
2. Access Control List
3. Any (Agent can take both node-based and ACL-based actions)
The choice of action space used during a training session is determined in the config_[name].yaml file.
**Node-Based**
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, num nodes] - Node ID (0 = nothing, node ID)
* [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state)
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
**Access Control List**
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
* [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
**ANY**
The agent is able to carry out both **Node-Based** and **Access Control List** operations.
This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above.
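To make the two list formats concrete, the sketch below turns a dictionary value into a readable description, telling node-based and ACL entries apart by their length (four versus six placeholders). ``describe_action`` and the label tuples are hypothetical helpers written for this example; the labels are copied from the placeholder descriptions above.

```python
from typing import Dict, Sequence

NODE_PROPERTIES = ("nothing", "state", "SoftwareState", "service state", "file system state")
NODE_ACTIONS = ("nothing", "on / scan", "off / repair", "reset / patch / restore")
ACL_ACTIONS = ("do nothing", "create rule", "delete rule")
PERMISSIONS = ("DENY", "ALLOW")


def describe_action(action: Sequence[int]) -> Dict[str, object]:
    if len(action) == 4:  # node-based: [node, property, action, service]
        node_id, prop, act, service = action
        return {
            "kind": "NODE",
            "node_id": node_id,
            "property": NODE_PROPERTIES[prop],
            "action": NODE_ACTIONS[act],
            "service": service,
        }
    if len(action) == 6:  # ACL: [action, permission, source, dest, protocol, port]
        act, permission, source, dest, protocol, port = action
        return {
            "kind": "ACL",
            "action": ACL_ACTIONS[act],
            "permission": PERMISSIONS[permission],
            "source": source,  # 0 = any
            "dest": dest,  # 0 = any
            "protocol": protocol,  # 0 = any
            "port": port,  # 0 = any
        }
    raise ValueError("expected 4 (node-based) or 6 (ACL) placeholders")


print(describe_action([2, 3, 3, 1])["action"])  # reset / patch / restore
```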
Rewards
*******
A reward value is presented back to the blue agent on the conclusion of every step. The reward value is calculated via two methods which combine to give the total value:
1. Node and service status
2. IER status
**Node and service status**
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values.
**IER status**
On every step, the full IER set is examined to determine whether green and red agent IERs are being permitted to run. Any red agent IERs running incur a penalty; any green agent
IERs not permitted to run also incur a penalty. See :ref:`config` for details of reward values.
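The IER part of the reward can be sketched as a simple sum of penalties. The function below is illustrative only: its name, the ``(colour, running)`` tuples and the example penalty values are assumptions, and the real values come from the reward settings in the training config.

```python
from typing import Iterable, Tuple


def ier_status_reward(
    iers: Iterable[Tuple[str, bool]],
    red_ier_running: float,
    green_ier_blocked: float,
) -> float:
    """Sum the penalties for running red IERs and blocked green IERs."""
    total = 0.0
    for colour, running in iers:
        if colour == "RED" and running:
            total += red_ier_running  # adversary traffic got through
        elif colour == "GREEN" and not running:
            total += green_ier_blocked  # legitimate traffic was stopped
    return total


# Hypothetical penalty values, chosen only for this example.
step_iers = [("GREEN", True), ("GREEN", False), ("RED", True), ("RED", False)]
print(ier_status_reward(step_iers, red_ier_running=-10.0, green_ier_blocked=-5.0))  # -15.0
```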
Future Enhancements
*******************
The PrimAITE project has an ambition to include the following enhancements in future releases:
* Integration with a suitable standardised framework to allow multi-agent integration
* Integration with external threat emulation tools, either using off-line data, or integrating at runtime

489
docs/source/config.rst Normal file

@@ -0,0 +1,489 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _config:
The Config Files Explained
==========================
PrimAITE uses two configuration files for its operation:
* **The Training Config**
Used to define the top-level settings of the PrimAITE environment, the reward values, and the session that is to be run.
* **The Lay Down Config**
Used to define the low-level settings of a session, including the network laydown, green / red agent Information Exchange Requirements (IERs) and Access Control List rules.
Training Config:
*******************
The Training Config file consists of the following attributes:
**Generic Config Values**
* **agent_framework** [enum]
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
* NONE - Where a user developed agent is to be used
* SB3 - Stable Baselines3
* RLLIB - Ray RLlib.
* **agent_identifier**
This identifies the agent to use for the session. Select from one of the following:
* A2C - Advantage Actor Critic
* PPO - Proximal Policy Optimization
* HARDCODED - A custom built deterministic agent
* RANDOM - A Stochastic random agent
* **random_red_agent** [bool]
Determines if the session should be run with a random red agent
* **action_type** [enum]
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
* **observation_space** [dict]
Allows the user to configure the observation space by combining one or more observation components. The list of available
components is in :py:mod:`primaite.environment.observations`.
The observation space config item should have a ``components`` key which is a list of components. Each component
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
component while it is being initialised.
This example illustrates the correct format for the observation space config item
.. code-block:: yaml
   observation_space:
     components:
       - name: NODE_LINK_TABLE
       - name: NODE_STATUSES
       - name: LINK_TRAFFIC_LEVELS
         options:
           combine_service_traffic: False
           quantisation_levels: 99
       - name: ACCESS_CONTROL_LIST
Currently available components are:
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>`, this does not accept any additional options
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
* ``quantisation_levels`` - how many discrete bandwidth usage levels to use for encoding. This can be an integer equal to or greater than 3.
The other configurable item is ``flatten``, which is false by default. When set to true, the observation space is flattened (turned into a 1-D vector). You should use this if your RL agent does not natively support observation space types like ``gym.spaces.Tuple``.
* **num_train_episodes** [int]
This defines the number of episodes that the agent will train for.
* **num_train_steps** [int]
Determines the number of steps to run in each episode of the training session.
* **num_eval_episodes** [int]
This defines the number of episodes that the agent will be evaluated over.
* **num_eval_steps** [int]
Determines the number of steps to run in each episode of the evaluation session.
* **time_delay** [int]
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
* **session_type** [text]
Type of session to be run (TRAINING, EVALUATION, or BOTH)
* **load_agent** [bool]
Determine whether to load an agent from file
* **agent_load_file** [text]
File path and file name of agent if you're loading one in
* **observation_space_high_value** [int]
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
* **implicit_acl_rule** [str]
Determines which implicit rule the ACL has. The two options are DENY or ALLOW.
* **max_number_acl_rules** [int]
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
**Reward-Based Config Values**
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
* **Generic [all_ok]** [float]
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
* **Node Hardware State [off_should_be_on]** [float]
The score to give when the node should be on, but is off
* **Node Hardware State [off_should_be_resetting]** [float]
The score to give when the node should be resetting, but is off
* **Node Hardware State [on_should_be_off]** [float]
The score to give when the node should be off, but is on
* **Node Hardware State [on_should_be_resetting]** [float]
The score to give when the node should be resetting, but is on
* **Node Hardware State [resetting_should_be_on]** [float]
The score to give when the node should be on, but is resetting
* **Node Hardware State [resetting_should_be_off]** [float]
The score to give when the node should be off, but is resetting
* **Node Hardware State [resetting]** [float]
The score to give when the node is resetting
* **Node Operating System or Service State [good_should_be_patching]** [float]
The score to give when the state should be patching, but is good
* **Node Operating System or Service State [good_should_be_compromised]** [float]
The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patching_should_be_good]** [float]
The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [float]
The score to give when the state is patching
* **Node Operating System or Service State [compromised_should_be_good]** [float]
The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [float]
The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [float]
The score to give when the state is overwhelmed
* **Node File System State [good_should_be_repairing]** [float]
The score to give when the state should be repairing, but is good
* **Node File System State [good_should_be_restoring]** [float]
The score to give when the state should be restoring, but is good
* **Node File System State [good_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is good
* **Node File System State [good_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is good
* **Node File System State [repairing_should_be_good]** [float]
The score to give when the state should be good, but is repairing
* **Node File System State [repairing_should_be_restoring]** [float]
The score to give when the state should be restoring, but is repairing
* **Node File System State [repairing_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is repairing
* **Node File System State [repairing_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is repairing
* **Node File System State [repairing]** [float]
The score to give when the state is repairing
* **Node File System State [restoring_should_be_good]** [float]
The score to give when the state should be good, but is restoring
* **Node File System State [restoring_should_be_repairing]** [float]
The score to give when the state should be repairing, but is restoring
* **Node File System State [restoring_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is restoring
* **Node File System State [restoring_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is restoring
* **Node File System State [restoring]** [float]
The score to give when the state is restoring
* **Node File System State [corrupt_should_be_good]** [float]
The score to give when the state should be good, but is corrupt
* **Node File System State [corrupt_should_be_repairing]** [float]
The score to give when the state should be repairing, but is corrupt
* **Node File System State [corrupt_should_be_restoring]** [float]
The score to give when the state should be restoring, but is corrupt
* **Node File System State [corrupt_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is corrupt
* **Node File System State [corrupt]** [float]
The score to give when the state is corrupt
* **Node File System State [destroyed_should_be_good]** [float]
The score to give when the state should be good, but is destroyed
* **Node File System State [destroyed_should_be_repairing]** [float]
The score to give when the state should be repairing, but is destroyed
* **Node File System State [destroyed_should_be_restoring]** [float]
The score to give when the state should be restoring, but is destroyed
* **Node File System State [destroyed_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is destroyed
* **Node File System State [destroyed]** [float]
The score to give when the state is destroyed
* **Node File System State [scanning]** [float]
The score to give when the state is scanning
* **IER Status [red_ier_running]** [float]
The score to give when a red agent IER is permitted to run
* **IER Status [green_ier_blocked]** [float]
The score to give when a green agent IER is prevented from running
**Patching / Reset Durations**
* **os_patching_duration** [int]
The number of steps to take when patching an Operating System
* **node_reset_duration** [int]
The number of steps to take when resetting a node's hardware state
* **service_patching_duration** [int]
The number of steps to take when patching a service
* **file_system_repairing_limit** [int]
The number of steps to take when repairing the file system
* **file_system_restoring_limit** [int]
The number of steps to take when restoring the file system
* **file_system_scanning_limit** [int]
The number of steps to take when scanning the file system
* **deterministic** [bool]
Set to true if the agent evaluation should be deterministic. Default is ``False``
* **seed** [int]
Seed used in the randomisation in agent training. Default is ``None``
The Lay Down Config
*******************
The lay down config file consists of the following attributes:
* **itemType: STEPS** [int]
* **item_type: PORTS** [int]
Provides a list of ports modelled in this session
* **item_type: SERVICES** [freetext]
Provides a list of services modelled in this session
* **item_type: NODE**
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
* **node_class** [enum]: Relates to the base type of the node. Can be SERVICE, ACTIVE or PASSIVE. PASSIVE nodes do not have an operating system or services. ACTIVE nodes have an operating system, but no services. SERVICE nodes have both an operating system and one or more services
* **node_type** [enum]: Relates to the component type. Can be one of CCTV, SWITCH, COMPUTER, LINK, MONITOR, PRINTER, LOP, RTU, ACTUATOR or SERVER
* **priority** [enum]: Provides a priority for each node. Can be one of P1, P2, P3, P4 or P5 (with P1 being the highest)
* **hardware_state** [enum]: The initial hardware state of the node. Can be one of ON, OFF or RESETTING
* **ip_address** [IP address]: The IP address of the component in format xxx.xxx.xxx.xxx
* **software_state** [enum]: The initial state of the node operating system. Can be GOOD, PATCHING or COMPROMISED
* **file_system_state** [enum]: The initial state of the node file system. Can be GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING
* **services**: For each service associated with the node:
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
* **item_type: LINK**
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
* **bandwidth** [int]: The bandwidth (in bits/s) of the link
* **source** [int]: The ID of the source node
* **destination** [int]: The ID of the destination node
* **item_type: GREEN_IER**
Defines a green agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin
* **end_step** [int]: The end step (in the episode) for this IER to finish
* **load** [int]: The load (in bits/s) for this IER to apply to links
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
* **source** [int]: The ID of the source node
* **destination** [int]: The ID of the destination node
* **mission_criticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
* **item_type: RED_IER**
Defines a red agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin
* **end_step** [int]: The end step (in the episode) for this IER to finish
* **load** [int]: The load (in bits/s) for this IER to apply to links
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
* **source** [int]: The ID of the source node
* **destination** [int]: The ID of the destination node
* **mission_criticality** [enum]: Not currently used. Default to 0
* **item_type: GREEN_POL**
Defines a green agent pattern-of-life instruction. It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this PoL to begin
* **end_step** [int]: Not currently used. Default to same as start step
* **nodeId** [int]: The ID of the node to apply the PoL to
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
* **state** [enum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
* **item_type: RED_POL**
Defines a red agent pattern-of-life instruction. It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this PoL to begin
* **end_step** [int]: Not currently used. Default to same as start step
* **targetNodeId** [int]: The ID of the node to apply the PoL to
* **initiator** [enum]: What initiates the PoL. Can be DIRECT, IER or SERVICE
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
* **state** [enum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state) or GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING (for file system state)
* **sourceNodeId** [int]: The ID of the source node containing the service to check (used for SERVICE initiator)
* **sourceNodeService** [freetext]: The service on the source node to check (used for SERVICE initiator). Must match a value in the services list for this node
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
* **item_type: ACL_RULE**
Defines an initial Access Control List (ACL) rule. It should consist of:
* **id** [int]: Unique ID for this YAML item
* **permission** [enum]: Defines either an allow or deny rule. Value must be either DENY or ALLOW
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
* **position** [int]: Defines where to place the ACL rule in the list. Rules with a lower index (higher up in the list) are checked first. Indexing starts at 0, as with Python lists.
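The effect of ``position`` can be illustrated with a minimal sketch. The ``AclRule`` dataclass and ``first_match`` function below are hypothetical stand-ins written for this example, not PrimAITE classes, and the sketch assumes that the first matching rule (lowest position) decides the outcome:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AclRule:
    """Hypothetical stand-in for an ACL_RULE laydown item (not a PrimAITE class)."""

    permission: str  # "ALLOW" or "DENY"
    source: str
    destination: str
    protocol: str
    port: int
    position: int


def first_match(rules: List[AclRule], source: str, destination: str, protocol: str, port: int) -> Optional[str]:
    """Return the permission of the lowest-position rule that matches, or None."""
    for rule in sorted(rules, key=lambda r: r.position):
        if (rule.source, rule.destination, rule.protocol, rule.port) == (source, destination, protocol, port):
            return rule.permission
    return None


rules = [
    AclRule("ALLOW", "192.168.1.2", "192.168.1.3", "TCP", 80, position=1),
    AclRule("DENY", "192.168.1.2", "192.168.1.3", "TCP", 80, position=0),
]
# The DENY rule sits at position 0, so it is checked before the ALLOW rule.
decision = first_match(rules, "192.168.1.2", "192.168.1.3", "TCP", 80)
print(decision)
```

Both rules match the same traffic, so the order in which they are checked decides the result.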


@@ -0,0 +1,142 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Custom Agents
=============
Integrating a user defined blue agent
*************************************
.. note::
If you plan to integrate custom RL agents into PrimAITE, you must work from a clone of the repository. Custom agents are not supported if you install PrimAITE as a Python package from a wheel.
PrimAITE integrates with Ray RLlib and Stable-Baselines3 agents. All agents interface with PrimAITE through an :py:class:`Agent Session <primaite.agents.agent.AgentSessionABC>`, which handles input/output of agent save files and captures and plots performance metrics during training and evaluation. To integrate a custom blue agent, it is recommended that you create a subclass of :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint()``, ``load()``, and ``save()`` methods.
Below is a barebones example of a custom agent implementation:
.. code:: python
# src/primaite/agents/my_custom_agent.py
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite


class CustomAgent(AgentSessionABC):
    def __init__(self, training_config_path, lay_down_config_path):
        super().__init__(training_config_path, lay_down_config_path)
        assert self._training_config.agent_framework == AgentFramework.CUSTOM
        # CUSTOM_AGENT is the identifier added to AgentIdentifier in enums.py
        assert self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT
        self._setup()

    def _setup(self):
        super()._setup()
        self._env = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
        )
        self._agent = ...  # your code to setup agent

    def _save_checkpoint(self):
        checkpoint_num = self._training_config.checkpoint_every_n_episodes
        episode_count = self._env.episode_count
        save_checkpoint = False
        if checkpoint_num:
            save_checkpoint = episode_count % checkpoint_num == 0
        # saves checkpoint if the episode count is not 0 and save_checkpoint flag was set to true
        if episode_count and save_checkpoint:
            ...
            # your code to save checkpoint goes here.
            # The path should start with self.checkpoints_path and include the episode number.

    def learn(self):
        ...
        # call your agent's learning function here.
        super().learn()  # this will finalise learning and output session metadata
        self.save()

    def evaluate(self):
        ...
        # call your agent's evaluation function here.
        self._env.close()
        super().evaluate()

    def _get_latest_checkpoint(self):
        ...
        # Load an agent from file.

    @classmethod
    def load(cls, path):
        ...
        # Create a CustomAgent object which loads model weights from file.

    def save(self):
        ...
        # Call your agent's function that saves it to a file
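The checkpoint cadence used in ``_save_checkpoint()`` above can be checked in isolation. The helper below is a standalone sketch whose function name is an assumption made for this example; it is not part of PrimAITE:

```python
def should_save_checkpoint(episode_count: int, checkpoint_every_n_episodes: int) -> bool:
    """Mirror the cadence check from the ``_save_checkpoint()`` example."""
    if not checkpoint_every_n_episodes:
        return False  # a value of 0 (or None) disables checkpointing
    # Episode 0 never triggers a save, even though 0 % n == 0
    return episode_count != 0 and episode_count % checkpoint_every_n_episodes == 0


# Episodes on which a checkpoint would be written when saving every 10 episodes
saved_on = [episode for episode in range(0, 31) if should_save_checkpoint(episode, 10)]
print(saved_on)
```

With a cadence of 10 over episodes 0 to 30, only episodes 10, 20 and 30 trigger a save.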
You will also need to modify :py:class:`PrimaiteSession <primaite.primaite_session.PrimaiteSession>` and :py:mod:`primaite.common.enums` to register your new agent identifier.
.. code-block:: python
:emphasize-lines: 17, 18
# src/primaite/common/enums.py

class AgentIdentifier(Enum):
    """The Red Agent algo/class."""
    A2C = 1
    "Advantage Actor Critic"
    PPO = 2
    "Proximal Policy Optimization"
    HARDCODED = 3
    "The Hardcoded agents"
    DO_NOTHING = 4
    "The DoNothing agents"
    RANDOM = 5
    "The RandomAgent"
    DUMMY = 6
    "The DummyAgent"
    CUSTOM_AGENT = 7
    "Your custom agent"
.. code-block:: python
:emphasize-lines: 3, 11, 12
# src/primaite/primaite_session.py

from primaite.agents.my_custom_agent import CustomAgent

# ...

def setup(self):
    """Performs the session setup."""
    if self._training_config.agent_framework == AgentFramework.CUSTOM:
        _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
        if self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT:
            self._agent_session = CustomAgent(self._training_config_path, self._lay_down_config_path)
        if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.HARDCODED}")
            if self._training_config.action_type == ActionType.NODE:
                # Deterministic Hardcoded Agent with Node Action Space
                self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
Finally, specify your agent in your training config.
.. code-block:: yaml
# ~/primaite/2.0.0/config/path/to/your/config_main.yaml
# Training Config File
agent_framework: CUSTOM
agent_identifier: CUSTOM_AGENT
random_red_agent: False
# ...
Now you can :ref:`run a primaite session<run a primaite session>` with your custom agent by passing in the custom ``config_main``.


@@ -0,0 +1,14 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. role:: raw-html(raw)
:format: html
Dependencies
============
PrimAITE Dependencies
---------------------
.. include:: primaite-dependencies.rst


@@ -0,0 +1,149 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _getting-started:
Getting Started
===============
**Getting Started with PrimAITE**
Pre-Requisites
To install **PrimAITE**, you need a Python version between 3.8 and 3.10 (inclusive). If you don't already have one, this is how to install it:
.. code-block:: bash
:caption: Unix
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt install python3.10
sudo apt-get install python3-pip
sudo apt-get install python3-venv
.. code-block:: text
:caption: Windows (Powershell)
- Manual install from: https://www.python.org/downloads/release/python-31011/
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
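Before installing, it can help to confirm that the active interpreter falls within the supported range stated in the pre-requisites. The check below is a minimal sketch; the function name ``is_supported_python`` is an assumption made for this example and is not part of PrimAITE:

```python
import sys
from typing import Tuple


def is_supported_python(version: Tuple[int, int]) -> bool:
    """Return True for Python 3.8, 3.9 and 3.10 (the range given in the pre-requisites)."""
    return (3, 8) <= version <= (3, 10)


# Check the interpreter that is running this script
print(is_supported_python(sys.version_info[:2]))
```

Run it with the interpreter you intend to use for the virtual environment.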
Install PrimAITE
****************
1. Create a primaite directory in your home directory:
.. code-block:: bash
:caption: Unix
mkdir -p ~/primaite/2.0.0
.. code-block:: powershell
:caption: Windows (Powershell)
mkdir ~\primaite\2.0.0
2. Navigate to the primaite directory and create a new python virtual environment (venv)
.. code-block:: bash
:caption: Unix
cd ~/primaite/2.0.0
python3 -m venv .venv
.. code-block:: powershell
:caption: Windows (Powershell)
cd ~\primaite\2.0.0
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
3. Activate the venv
.. code-block:: bash
:caption: Unix
source .venv/bin/activate
.. code-block:: powershell
:caption: Windows (Powershell)
.\.venv\Scripts\activate
4. Install PrimAITE using pip from PyPI
.. code-block:: bash
:caption: Unix
pip install primaite
.. code-block:: powershell
:caption: Windows (Powershell)
pip install primaite
5. Perform the PrimAITE setup
.. code-block:: bash
:caption: Unix
primaite setup
.. code-block:: powershell
:caption: Windows (Powershell)
primaite setup
Clone & Install PrimAITE for Development
****************************************
To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
of your choice:
.. TODO:: Add repo path once we know what it is
.. code-block:: bash
git clone <repo path>
cd primaite
Create and activate your Python virtual environment (venv)
.. code-block:: bash
:caption: Unix
python3 -m venv venv
source venv/bin/activate
.. code-block:: powershell
:caption: Windows (Powershell)
python3 -m venv venv
.\venv\Scripts\activate
Install PrimAITE with the dev extra
.. code-block:: bash
:caption: Unix
pip install -e .[dev]
.. code-block:: powershell
:caption: Windows (Powershell)
pip install -e .[dev]
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).

81
docs/source/glossary.rst Normal file

@@ -0,0 +1,81 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Glossary
=============
.. glossary::
:sorted:
Network
The network in primaite is a logical representation of a computer network containing :term:`Nodes<Node>` and :term:`Links<Link>`.
Node
A Node represents a network endpoint. For example a computer, server, switch, or an actuator.
Link
A Link represents the connection between two Nodes. For example, a physical wire between a computer and a switch or a wireless connection.
Protocol
Protocols are used by links to separate different types of network traffic. Common examples would be HTTP, TCP, and UDP.
Service
A service represents a piece of software that is installed on a node, such as a web server or a database.
Access Control List
PrimAITE blocks or allows certain traffic on the network by simulating firewall rules, which are defined in the Access Control List.
Agent
An agent is a representation of a user of the network. Typically this would be a user that is using one of the computer nodes, though it could be an autonomous agent.
Green Agent
Simulates typical benign activity on the network, such as real users using computers and servers.
Red Agent
An agent that is aiming to attack the network in some way, for example by executing a Denial-Of-Service attack or stealing data.
Blue Agent
A defensive agent that protects the network from Red Agent attacks to minimise disruption to green agents and protect data.
Information Exchange Requirement (IER)
Simulates network traffic by sending data from one network node to another via links for a specified amount of time. IERs can be part of green agent behaviour or red agent behaviour. PrimAITE can be configured to apply a penalty for green agents' IERs being blocked and a reward for red agents' IERs being blocked.
Pattern-of-Life (PoL)
PoLs allow agents to change the current hardware, OS, file system, or service statuses of nodes during the course of an episode. For example, a green agent may restart a server node to represent scheduled maintenance. A red agent's Pattern-of-Life can be used to attack nodes by changing their states to CORRUPT or COMPROMISED.
Reward
The reward is a single number used by the blue agent to understand whether it's performing well or poorly. RL agents change their behaviour in an attempt to increase the expected reward each episode. The reward is generated based on the current states of the environment / :term:`reference environment` and is impacted positively by things like green IERs running successfully and negatively by things like nodes being compromised.
Observation
An observation is a representation of the current state of the environment that is given to the learning agent so it can decide on which action to perform. If the environment is 'fully observable', the observation contains information about every possible aspect of the environment. More commonly, the environment is 'partially observable' which means the learning agent has to make decisions without knowing every detail of the current environment state.
Action
The learning agent decides on an action to take on every step in the simulation. The action has the chance to positively or negatively impact the environment state. Over time, the agent aims to learn which actions to take when to maximise the expected reward.
Training
During training, an RL agent is placed in the simulated network and it learns which actions to take in which scenarios to obtain maximum reward.
Evaluation
During evaluation, an RL agent acts on the simulated network but it is not allowed to update its behaviour. Evaluation is used to assess how successful agents are at defending the network.
Step
The agents can only act in the environment at discrete intervals. The time step is the basic unit of time in the simulation. At each step, the RL agent has an opportunity to observe the state of the environment and decide an action. Steps are also used for updating states for time-dependent activities such as rebooting a node.
Episode
When an episode starts, the network simulation is reset to an initial state. The agents take actions on each step of the episode until it reaches a terminal state, which usually happens after a predetermined number of steps. After the terminal state is reached, a new episode starts and the RL agent has another opportunity to protect the network.
Reference environment
While the network simulation is unfolding, a parallel simulation takes place which is identical to the main one except that blue and red agent actions are not applied. This reference environment essentially shows what would be happening to the network if there had been no cyberattack or defence. The reference environment is used to calculate rewards.
Transaction
PrimAITE records the decisions of the learning agent by saving its observation, action, and reward at every time step. During each session, this data is saved to disk to allow for full inspection.
Laydown
The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours.
Gym
PrimAITE uses the Gym reinforcement learning framework API to create a training environment and interface with RL agents. Gym defines a common way of creating observations, actions, and rewards.
User app home
PrimAITE supports upgrading the software version while retaining user data. The user data directory is where configs, notebooks, and results are stored. This location is ``~/primaite/<version>`` on Linux/Darwin and ``C:\Users\<username>\primaite\<version>`` on Windows.
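The Step, Episode, Observation, Action and Reward entries above fit together as a loop. The sketch below uses a toy environment to show that loop; ``ToyEnv`` and ``run_episodes`` are hypothetical names invented for this example and do not reproduce PrimAITE's Gym environment or its reward function:

```python
import random
from typing import List, Tuple


class ToyEnv:
    """Hypothetical stand-in environment; it is not PrimAITE's Gym environment."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.step_count = 0

    def reset(self) -> int:
        """Start a new episode and return the initial observation."""
        self.step_count = 0
        return 0

    def step(self, action: int) -> Tuple[int, float, bool]:
        """Apply an action, then return (observation, reward, done)."""
        self.step_count += 1
        reward = 1.0 if action == 1 else -1.0
        done = self.step_count >= self.max_steps  # terminal state after a fixed number of steps
        return self.step_count, reward, done


def run_episodes(env: ToyEnv, num_episodes: int, seed: int = 0) -> List[float]:
    """Run episodes with a random policy and return the total reward of each one."""
    rng = random.Random(seed)
    totals = []
    for _ in range(num_episodes):
        observation = env.reset()  # each episode starts from the initial state
        total, done = 0.0, False
        while not done:
            action = rng.choice([0, 1])  # a learning agent would use the observation here
            observation, reward, done = env.step(action)
            total += reward
        totals.append(total)
    return totals


episode_rewards = run_episodes(ToyEnv(max_steps=5), num_episodes=3)
print(episode_rewards)
```

A trained agent would replace the random choice with a policy that maps each observation to an action.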


@@ -0,0 +1,57 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
v1.2 to v2.0 Migration guide
============================
**1. Installing PrimAITE**
As before, you can install PrimAITE from the repository by running ``pip install -e .``. However, there is now an additional setup step that does several things, such as setting up user directories and copying default configs and notebooks. Once you have installed PrimAITE into your virtual environment, run this command to finalise setup.
.. code-block:: bash
primaite setup
**2. Running a training session**
In version 1.2 of PrimAITE, the main entry point for training or evaluating agents was the ``src/primaite/main.py`` file. v2.0.0 introduced managed 'sessions' which are responsible for reading configuration files, performing training, and writing outputs.
The ``main.py`` file still runs a training session, but it now uses the new ``PrimaiteSession`` and requires you to provide the paths to your config files.
.. code-block:: bash
python src/primaite/main.py --tc path/to/training-config.yaml --ldc path/to/laydown-config.yaml
Alternatively, the session can be invoked from the command line by running:
.. code-block:: bash
primaite session --tc path/to/training-config.yaml --ldc path/to/laydown-config.yaml
**3. Location of configs**
In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, so when you install and set up PrimAITE, config files are now stored in your user data location. On Linux/OSX, this is ``~/primaite/2.0.0/config``. On Windows, this is ``C:\Users\<your username>\primaite\2.0.0\config``. Upon first setup, the config folder is populated with some default YAML files. It is recommended that you store all your custom configuration files here.
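The user config location can also be derived in code. The sketch below assumes the ``<home>/primaite/<version>/config`` layout described above; the function name ``user_config_dir`` is an assumption made for this example, not a PrimAITE function:

```python
from pathlib import Path


def user_config_dir(version: str, home: Path) -> Path:
    """Build the user config directory for a given PrimAITE version (sketch)."""
    return home / "primaite" / version / "config"


# Path.home() resolves to the current user's home directory on any platform
print(user_config_dir("2.0.0", Path.home()))
```

Passing the home directory as an argument keeps the helper easy to test against a fixed path.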
**4. Contents of configs**
Some things that were previously part of the laydown config are now part of the training config.
* Actions
If you have custom configs which use these, you will need to adapt them by moving the configuration from the laydown config to the training config.
Also, there are new configurable items in the training config:
* Observations
* Agent framework
* Agent
* Deep learning framework
* random red agents
* seed
* deterministic
* hard coded agent view
Each of these items has a default value designed so that PrimAITE behaves as it did in 1.2.0, so you do not have to specify them.
ACL rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will be placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.
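For laydowns with many rules, adding the positions can be scripted. The sketch below works on the laydown items once they have been loaded into a list of dictionaries. It assumes that none of the rules has a position yet and that file order is the intended check order; the function name ``add_acl_positions`` is hypothetical:

```python
from typing import Any, Dict, List


def add_acl_positions(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a copy of the laydown items with a position on every ACL_RULE."""
    migrated = []
    next_position = 0
    for item in items:
        item = dict(item)  # shallow copy so the input list is left untouched
        if item.get("item_type") == "ACL_RULE":
            # Assumption: rules are numbered in file order, starting at 0
            item["position"] = next_position
            next_position += 1
        migrated.append(item)
    return migrated


legacy_items = [
    {"item_type": "NODE", "id": 1},
    {"item_type": "ACL_RULE", "id": 2, "permission": "ALLOW"},
    {"item_type": "ACL_RULE", "id": 3, "permission": "DENY"},
]
migrated_items = add_acl_positions(legacy_items)
print([item.get("position") for item in migrated_items])
```

Review the resulting order by hand, since the position determines which rule is checked first.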


@@ -0,0 +1,212 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _run a primaite session:
Run a PrimAITE Session
======================
Run
---
A PrimAITE session can be run either with the ``primaite session`` command from the CLI
(see :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
Both ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = "<path to training config yaml file>"
lay_down_config = "<path to lay down config yaml file>"
run(training_config, lay_down_config)
When a session is run, a session output sub-directory is created in the user's app sessions directory (``~/primaite/2.0.0/sessions``).
The sub-directory is named as follows: ``~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>/``
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``.
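The naming scheme can be reproduced with the standard library. This is a sketch of the format only, not the code PrimAITE uses internally, and the function name ``session_dir`` is an assumption made for this example:

```python
from datetime import datetime
from pathlib import Path


def session_dir(sessions_root: Path, started: datetime) -> Path:
    """Build <root>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss> for a session start time."""
    date_str = started.strftime("%Y-%m-%d")
    time_str = started.strftime("%H-%M-%S")
    return sessions_root / date_str / f"{date_str}_{time_str}"


# The example from the text: 17:30:00 on 31st January 2023
example = session_dir(Path("~/primaite/2.0.0/sessions"), datetime(2023, 1, 31, 17, 30, 0))
print(example.as_posix())
```

The tilde is kept literal here; call ``Path.expanduser()`` on the result if a resolved path is needed.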
``primaite session`` can be run in the terminal/command prompt without arguments, in which case it uses the default configs in the directory ``primaite/config/example_config``.
To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``--legacy-ldc`` options.
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = "<path to legacy training config yaml file>"
lay_down_config = "<path to legacy lay down config yaml file>"
run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True)
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents (training checkpoints and a final trained agent)
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date & time the session started, in ISO format.
* **end_datetime** - The date & time the session ended, in ISO format.
* **learning**
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **evaluation**
* **total_episodes** - The total number of evaluation episodes completed.
* **total_time_steps** - The total number of evaluation time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
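The metadata file can be inspected with the standard ``json`` module. The sketch below writes a sample file to a temporary directory and reads it back. The nesting of the sample follows the field list above, but the values are made up and the exact structure of a real ``session_metadata.json`` is an assumption; ``summarise`` is a hypothetical helper:

```python
import json
import tempfile
from pathlib import Path

# Sample data shaped like the field list above; all values are invented
sample_metadata = {
    "uuid": "00000000-0000-0000-0000-000000000000",
    "start_datetime": "2023-07-18T11:06:04",
    "end_datetime": "2023-07-18T11:16:04",
    "learning": {"total_episodes": 10, "total_time_steps": 2560},
    "evaluation": {"total_episodes": 5, "total_time_steps": 1280},
    "env": {"training_config": {}, "lay_down_config": {}},
}


def summarise(metadata_path: Path) -> str:
    """Read a session metadata file and report the episode counts."""
    with open(metadata_path, "r") as file:
        metadata = json.load(file)
    learning = metadata["learning"]
    evaluation = metadata["evaluation"]
    return f"{learning['total_episodes']} learning episodes, {evaluation['total_episodes']} evaluation episodes"


with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "session_metadata.json"
    path.write_text(json.dumps(sample_metadata))
    summary = summarise(path)

print(summary)
```

To use it on a real session, point ``summarise`` at the ``session_metadata.json`` file in that session's directory.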
**Results**
PrimAITE automatically creates two sets of results from each learning and evaluation session:
* Average reward per episode - a CSV file listing the average reward for each episode of the session. This indicates, for example, how the reward value changes over a training session.
* All transactions - a CSV file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
**Diagrams**
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using Plotly and saves it to the learning or evaluation subdirectory in the session directory.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
**Example Session Directory Structure**
.. code-block:: text
~/
└── primaite/
└── 2.0.0/
└── sessions/
└── 2023-07-18/
└── 2023-07-18_11-06-04/
├── evaluation/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
├── learning/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
│ ├── checkpoints/
│ │ └── sb3ppo_10.zip
│ ├── SB3_PPO.zip
│ └── tensorboard_logs/
│ ├── PPO_1/
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
│ ├── PPO_2/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
│ ├── PPO_3/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
│ ├── PPO_4/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
│ ├── PPO_5/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
│ ├── PPO_6/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
│ ├── PPO_7/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
│ ├── PPO_8/
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
│ ├── PPO_9/
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
│ └── PPO_10/
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
├── network_2023-07-18_11-06-04.png
└── session_metadata.json
Loading a session
-----------------
A previous session can be loaded by providing the **directory** of the previous session either to the ``primaite session`` command from the CLI
(see :func:`primaite.cli.session`), or to :func:`primaite.main.run` via the ``session_path`` argument.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path="<previous session directory>")
When PrimAITE runs a loaded session, it writes its output to the provided session directory.
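To find the directory to pass in, the most recent session can be located with ``pathlib``. The sketch below relies on the ``<date>/<date>_<time>`` naming described earlier, whose zero-padded names sort chronologically. The function name ``latest_session_dir`` is an assumption made for this example, and the demo runs against a temporary directory rather than a real sessions folder:

```python
import tempfile
from pathlib import Path
from typing import Optional


def latest_session_dir(sessions_root: Path) -> Optional[Path]:
    """Return the most recent <date>/<date>_<time> directory, or None if there is none."""
    candidates = [path for path in sessions_root.glob("*/*") if path.is_dir()]
    # Zero-padded timestamps mean the largest name is also the latest session
    return max(candidates, key=lambda path: path.name, default=None)


with tempfile.TemporaryDirectory() as tmp_dir:
    root = Path(tmp_dir)
    for name in ("2023-07-17/2023-07-17_09-00-00", "2023-07-18/2023-07-18_11-06-04"):
        (root / name).mkdir(parents=True)
    latest = latest_session_dir(root)
    latest_name = latest.name if latest else None

print(latest_name)
```

The returned path can then be supplied as the session directory when loading.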

88
pyproject.toml Normal file

@@ -0,0 +1,88 @@
[build-system]
requires = ["setuptools", "setuptools-scm", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "primaite"
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
license = {file = "LICENSE"}
requires-python = ">=3.8, <3.11"
dynamic = ["version", "readme"]
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 5 - Production/Stable",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"Operating System :: POSIX :: Linux",
"Operating System :: Unix",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"gym==0.21.0",
"jupyterlab==3.6.1",
"kaleido==0.2.1",
"matplotlib==3.7.1",
"networkx==3.1",
"numpy==1.23.5",
"platformdirs==3.5.1",
"plotly==5.15.0",
"polars==0.18.4",
"PyYAML==6.0",
"ray[rllib]==2.2.0",
"stable-baselines3==1.6.2",
"tensorflow==2.12.0",
"typer[all]==0.9.0"
]
[tool.setuptools.dynamic]
version = {file = ["src/primaite/VERSION"]}
readme = {file = ["README.md"]}
[tool.setuptools]
package-dir = {"" = "src"}
include-package-data = true
license-files = ["LICENSE"]
[project.optional-dependencies]
dev = [
"build==0.10.0",
"flake8==6.0.0",
"furo==2023.3.27",
"gputil==1.4.0",
"pip-licenses==4.3.0",
"pre-commit==2.20.0",
"pylatex==1.4.1",
"pytest==7.2.0",
"pytest-xdist==3.3.1",
"pytest-cov==4.0.0",
"pytest-flake8==1.1.1",
"setuptools==66",
"Sphinx==6.1.3",
"sphinx-copybutton==0.5.2",
"wheel==0.38.4"
]
[project.scripts]
primaite = "primaite.cli:app"
[tool.isort]
profile = "black"
line_length = 120
force_sort_within_sections = "False"
order_by_type = "False"
[tool.black]
line-length = 120
[project.urls]
Homepage = "https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE"
Documentation = "https://Autonomous-Resilient-Cyber-Defence.github.io/PrimAITE/"
Repository = "https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE"
Changelog = "https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/blob/dev/CHANGELOG.md"

5
pytest.ini Normal file

@@ -0,0 +1,5 @@
[pytest]
testpaths =
tests
markers =
env_config_paths

4
setup.cfg Normal file

@@ -0,0 +1,4 @@
[metadata]
url = https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE
author = Defence Science and Technology Laboratory UK
author_email = oss@dstl.gov.uk

17
setup.py Normal file

@@ -0,0 +1,17 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from setuptools import setup
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel  # noqa


class bdist_wheel(_bdist_wheel):  # noqa
    def finalize_options(self):  # noqa
        super().finalize_options()
        # Set to False if you need to build OS and Python specific wheels
        self.root_is_pure = True  # noqa


setup(
    cmdclass={
        "bdist_wheel": bdist_wheel,
    }
)

1
src/primaite/VERSION Normal file

@@ -0,0 +1 @@
2.0.0

207
src/primaite/__init__.py Normal file

@@ -0,0 +1,207 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import logging
import logging.config
import shutil
import sys
from bisect import bisect
from logging import Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any, Dict, Final, List
import pkg_resources
import yaml
from platformdirs import PlatformDirs
with open(Path(__file__).parent.resolve() / "VERSION", "r") as file:
    __version__ = file.readline().strip()
class _PrimaitePaths:
    """
    A Primaite paths class that leverages PlatformDirs.

    The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__``.
    """

    def __init__(self):
        self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)

    def _get_dirs_properties(self) -> List[str]:
        class_items = self.__class__.__dict__.items()
        return [k for k, v in class_items if isinstance(v, property)]

    def mkdirs(self):
        """
        Creates all Primaite directories.

        Does this by retrieving all properties in the _PrimaitePaths class and calling each one.
        """
        for p in self._get_dirs_properties():
            getattr(self, p)

    @property
    def user_home_path(self) -> Path:
        """The PrimAITE user home path."""
        path = Path.home() / "primaite" / __version__
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def user_sessions_path(self) -> Path:
        """The PrimAITE user sessions path."""
        path = self.user_home_path / "sessions"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def user_config_path(self) -> Path:
        """The PrimAITE user config path."""
        path = self.user_home_path / "config"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def user_notebooks_path(self) -> Path:
        """The PrimAITE user notebooks path."""
        path = self.user_home_path / "notebooks"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def app_home_path(self) -> Path:
        """The PrimAITE app home path."""
        path = self._dirs.user_data_path
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def app_config_dir_path(self) -> Path:
        """The PrimAITE app config directory path."""
        path = self._dirs.user_config_path
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def app_config_file_path(self) -> Path:
        """The PrimAITE app config file path."""
        return self.app_config_dir_path / "primaite_config.yaml"

    @property
    def app_log_dir_path(self) -> Path:
        """The PrimAITE app log directory path."""
        if sys.platform == "win32":
            path = self.app_home_path / "logs"
        else:
            path = self._dirs.user_log_path
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def app_log_file_path(self) -> Path:
        """The PrimAITE app log file path."""
        return self.app_log_dir_path / "primaite.log"

    def __repr__(self):
        properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()])
        return f"{self.__class__.__name__}({properties_str})"
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
def _host_primaite_config():
if not PRIMAITE_PATHS.app_config_file_path.exists():
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_host_primaite_config()
def _get_primaite_config() -> Dict:
config_path = PRIMAITE_PATHS.app_config_file_path
if not config_path.exists():
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
with open(config_path, "r") as file:
primaite_config = yaml.safe_load(file)
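log_level_map = {
"NOTSET": logging.NOTSET,
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
# Accept the canonical spelling too, matching the "WARNING" key used by logger_format
"WARNING": logging.WARNING,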
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
return primaite_config
_PRIMAITE_CONFIG = _get_primaite_config()
class _LevelFormatter(Formatter):
"""
A custom level-specific formatter.
Credit to: https://stackoverflow.com/a/68154386
"""
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
super().__init__()
if "fmt" in kwargs:
raise ValueError("Format string must be passed to level-surrogate formatters, not this one")
self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items())
def format(self, record: LogRecord) -> str:
"""Overrides ``Formatter.format``."""
idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
level, formatter = self.formats[idx]
return formatter.format(record)
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
}
)
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
filename=PRIMAITE_PATHS.app_log_file_path,
maxBytes=10485760, # 10MB
backupCount=9, # Max 100MB of logs
encoding="utf8",
)
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(_STREAM_HANDLER)
_LOGGER.addHandler(_FILE_HANDLER)
def getLogger(name: str) -> Logger: # noqa
"""
Get a PrimAITE logger.
:param name: The logger name. Use ``__name__``.
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
logging config.
"""
logger = logging.getLogger(name)
logger.setLevel(_PRIMAITE_CONFIG["log_level"])
return logger
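
The level-specific formatter in this file selects a format by bisecting a sorted list of (level, formatter) pairs. The following is a minimal, self-contained sketch of that lookup, not the packaged implementation: `LevelFormatter`, `format_at` and the format strings are illustrative names chosen for this example.

```python
import logging
from bisect import bisect
from typing import Dict


class LevelFormatter(logging.Formatter):
    """Illustrative formatter that picks a format string per log level."""

    def __init__(self, formats: Dict[int, str]) -> None:
        super().__init__()
        # Sorted (level, Formatter) pairs; levels are unique, so the
        # Formatter objects themselves are never compared.
        self.formats = sorted((level, logging.Formatter(fmt)) for level, fmt in formats.items())

    def format(self, record: logging.LogRecord) -> str:
        # (levelno,) sorts just before (levelno, formatter), so bisect lands on
        # the entry for this level, or on the next higher level if there is no
        # exact entry. hi= clamps anything above the top level to the last entry.
        idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
        _, formatter = self.formats[idx]
        return formatter.format(record)


_FORMATTER = LevelFormatter(
    {
        logging.DEBUG: "DBG %(message)s",
        logging.INFO: "%(message)s",
        logging.ERROR: "ERR %(message)s",
    }
)


def format_at(level: int, message: str = "hello") -> str:
    """Format a throwaway record at the given level (helper for this example only)."""
    record = logging.LogRecord("demo", level, "demo.py", 0, message, None, None)
    return _FORMATTER.format(record)
```

In this sketch a level with no entry of its own (WARNING here) falls through to the next higher configured level, and anything above the highest level reuses the last format.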


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Access Control List. Models firewall functionality."""


@@ -0,0 +1,198 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A class that implements the access control list for the network."""
import logging
from typing import Dict, Final, List, Union
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class AccessControlList:
"""Access Control List class."""
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
"""Init."""
# Implicit ALLOW or DENY firewall spec
self.acl_implicit_permission = implicit_permission
# Implicit rule in ACL list
if self.acl_implicit_permission == RulePermissionType.DENY:
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
else:
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
# Maximum number of ACL Rules in ACL
self.max_acl_rules: int = max_acl_rules
# A list of ACL Rules
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
@property
def acl(self) -> List[Union[ACLRule, None]]:
"""Public access method for private _acl."""
return self._acl + [self.acl_implicit_rule]
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
:param _rule: The rule object to check
:type _rule: ACLRule
:param _source_ip_address: Source IP address to compare
:type _source_ip_address: str
:param _dest_ip_address: Destination IP address to compare
:type _dest_ip_address: str
:return: True if there is a match, otherwise False.
:rtype: bool
"""
if (
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
):
return True
else:
return False
def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool:
"""
Checks for rules that block a protocol / port.
Args:
_source_ip_address: the source IP address to check
_dest_ip_address: the destination IP address to check
_protocol: the protocol to check
_port: the port to check
Returns:
Indicates block if all conditions are satisfied.
"""
for rule in self.acl:
if isinstance(rule, ACLRule):
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule.get_permission() == RulePermissionType.DENY:
return True
elif rule.get_permission() == RulePermissionType.ALLOW:
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(
self,
_permission: RulePermissionType,
_source_ip: str,
_dest_ip: str,
_protocol: str,
_port: str,
_position: str,
) -> None:
"""
Adds a new rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
_source_ip: the source IP address
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
"""
try:
position_index = int(_position)
except (TypeError, ValueError):  # int() raises ValueError for non-numeric strings
_LOGGER.info(f"Position {_position} could not be converted to integer.")
return
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
# Checks position is in correct range
if self.max_acl_rules - 1 > position_index > -1:
try:
_LOGGER.info(f"Position {position_index} is valid.")
# Check to see Agent will not overwrite current ACL in ACL list
if self._acl[position_index] is None:
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
# Adds rule
self._acl[position_index] = new_rule
else:
# Cannot overwrite it
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
return
except Exception:
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
else:
_LOGGER.info(f"Position {position_index} is invalid or would overwrite the implicit firewall rule")
def remove_rule(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Removes a rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
_source_ip: the source IP address
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
"""
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
delete_rule_hash = hash(rule_to_delete)
for index in range(0, len(self._acl)):
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
self._acl[index] = None
def remove_all_rules(self) -> None:
"""Removes all rules."""
for i in range(len(self._acl)):
self._acl[i] = None
def get_dictionary_hash(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> int:
"""
Produces a hash value for a rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
_source_ip: the source IP address
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
Returns:
Hash value based on rule parameters.
"""
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
return hash_value
def get_relevant_rules(
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
) -> Dict[int, ACLRule]:
"""Get all ACL rules that relate to the given arguments.
:param _source_ip_address: the source IP address to check
:param _dest_ip_address: the destination IP address to check
:param _protocol: the protocol to check
:param _port: the port to check
:return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[int, ACLRule]
"""
relevant_rules = {}
for index, rule in enumerate(self.acl):
# Skip empty (None) slots; only real rules can match
if isinstance(rule, ACLRule) and self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
):
# There's a matching rule. Key it by its position in the ACL (the implicit rule is last).
relevant_rules[index] = rule
return relevant_rules
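
The `is_blocked` logic above is a first-match evaluation: rules are checked in order, the first rule that matches decides the outcome, and the implicit rule at the end catches everything else. Below is a minimal sketch of that idea under simplified assumptions; `Permission`, `Rule` and this free-standing `is_blocked` are hypothetical stand-ins, not the PrimAITE classes.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class Permission(Enum):
    DENY = 1
    ALLOW = 2


@dataclass(frozen=True)
class Rule:
    permission: Permission
    source_ip: str
    dest_ip: str
    protocol: str
    port: str

    def matches(self, source_ip: str, dest_ip: str, protocol: str, port: str) -> bool:
        # "ANY" acts as a wildcard on each field of the rule.
        return (
            self.source_ip in ("ANY", source_ip)
            and self.dest_ip in ("ANY", dest_ip)
            and self.protocol in ("ANY", protocol)
            and self.port in ("ANY", str(port))
        )


def is_blocked(
    slots: List[Optional[Rule]], implicit: Rule, source_ip: str, dest_ip: str, protocol: str, port: str
) -> bool:
    """Return True if the first matching rule (or the implicit rule) denies the traffic."""
    for rule in [*slots, implicit]:
        # Empty slots are represented by None and are skipped.
        if rule is not None and rule.matches(source_ip, dest_ip, protocol, port):
            return rule.permission is Permission.DENY
    return True  # Nothing matched at all: treat as blocked


IMPLICIT_DENY = Rule(Permission.DENY, "ANY", "ANY", "ANY", "ANY")
SLOTS = [None, Rule(Permission.ALLOW, "192.168.1.2", "192.168.1.3", "TCP", "80"), None]

web_blocked = is_blocked(SLOTS, IMPLICIT_DENY, "192.168.1.2", "192.168.1.3", "TCP", "80")
ssh_blocked = is_blocked(SLOTS, IMPLICIT_DENY, "192.168.1.2", "192.168.1.3", "TCP", "22")
```

Here the explicit ALLOW rule lets the port 80 traffic through, while the port 22 traffic falls through to the implicit DENY.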


@@ -0,0 +1,87 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType


class ACLRule:
    """Access Control List Rule class."""

    def __init__(
        self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
    ) -> None:
        """
        Initialise an ACL Rule.

        :param _permission: The permission (ALLOW or DENY)
        :param _source_ip: The source IP address
        :param _dest_ip: The destination IP address
        :param _protocol: The rule protocol
        :param _port: The rule port
        """
        self.permission: RulePermissionType = _permission
        self.source_ip: str = _source_ip
        self.dest_ip: str = _dest_ip
        self.protocol: str = _protocol
        self.port: str = _port

    def __hash__(self) -> int:
        """
        Override the hash function.

        Returns:
            Returns hash of core parameters.
        """
        return hash(
            (
                self.permission,
                self.source_ip,
                self.dest_ip,
                self.protocol,
                self.port,
            )
        )

    def get_permission(self) -> RulePermissionType:
        """
        Gets the permission attribute.

        Returns:
            Returns permission attribute
        """
        return self.permission

    def get_source_ip(self) -> str:
        """
        Gets the source IP address attribute.

        Returns:
            Returns source IP address attribute
        """
        return self.source_ip

    def get_dest_ip(self) -> str:
        """
        Gets the destination IP address attribute.

        Returns:
            Returns destination IP address attribute
        """
        return self.dest_ip

    def get_protocol(self) -> str:
        """
        Gets the protocol attribute.

        Returns:
            Returns protocol attribute
        """
        return self.protocol

    def get_port(self) -> str:
        """
        Gets the port attribute.

        Returns:
            Returns port attribute
        """
        return self.port
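
`ACLRule` overrides `__hash__` but not `__eq__`, which is presumably why `remove_rule` in the access control list compares hash values rather than using `==`. The sketch below illustrates the difference with a hypothetical `HashOnlyRule` stand-in that takes plain strings instead of the real enum.

```python
class HashOnlyRule:
    """Hypothetical stand-in: hashes on its fields but keeps default (identity) equality."""

    def __init__(self, permission: str, source_ip: str, dest_ip: str, protocol: str, port: str) -> None:
        self.permission = permission
        self.source_ip = source_ip
        self.dest_ip = dest_ip
        self.protocol = protocol
        self.port = port

    def __hash__(self) -> int:
        return hash((self.permission, self.source_ip, self.dest_ip, self.protocol, self.port))


first = HashOnlyRule("ALLOW", "ANY", "ANY", "TCP", "80")
second = HashOnlyRule("ALLOW", "ANY", "ANY", "TCP", "80")

# Two distinct objects with identical fields hash to the same value...
same_hash = hash(first) == hash(second)
# ...but still compare unequal, because object equality defaults to identity.
same_object = first == second
```

A hash comparison can in principle collide for different field values, so code that relies on it is trading strictness for convenience.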


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Common interface between RL agents from different libraries and PrimAITE."""


@@ -0,0 +1,319 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Optional, Union
from uuid import uuid4
import primaite
from primaite import getLogger, PRIMAITE_PATHS
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER: Logger = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
class AgentSessionABC(ABC):
"""
An ABC that manages training and/or evaluation of agents in PrimAITE.
This class cannot be instantiated directly; subclasses must implement all of its
abstract methods.
"""
@abstractmethod
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
) -> None:
"""
Initialise an agent session from config files, or load a previous session.
If training configuration and laydown configuration are provided with a session path,
the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[Path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[Path, str]
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param session_path: directory path of the session to load
"""
# initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self.legacy_training_config = legacy_training_config
self.legacy_lay_down_config = legacy_lay_down_config
self.session_timestamp: datetime = datetime.now()
# convert session to path
if session_path is not None:
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
# load session
self.load(session_path)
else:
# set training config path
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(
self._training_config_path, legacy_file=legacy_training_config
)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = str(uuid4())
"The session UUID"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
"""The session timestamp as a string."""
return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
@property
def learning_path(self) -> Path:
"""The learning outputs path."""
path = self.session_path / "learning"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def evaluation_path(self) -> Path:
"""The evaluation outputs path."""
path = self.session_path / "evaluation"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def checkpoints_path(self) -> Path:
"""The Session checkpoints path."""
path = self.learning_path / "checkpoints"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def uuid(self) -> str:
"""The Agent Session UUID."""
return self._uuid
def _write_session_metadata_file(self) -> None:
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_path`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- learning: total_episodes and total_time_steps, both NULL.
- evaluation: total_episodes and total_time_steps, both NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": self.uuid,
"start_datetime": self.session_timestamp.isoformat(),
"end_datetime": None,
"learning": {"total_episodes": None, "total_time_steps": None},
"evaluation": {"total_episodes": None, "total_time_steps": None},
"env": {
"training_config": self._training_config.to_dict(json_serializable=True),
"lay_down_config": self._lay_down_config,
},
}
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished writing session metadata file")
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
Updates the ``session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of episodes completed, stored under ``learning`` or ``evaluation``.
- total_time_steps: The total number of time steps completed, stored under ``learning`` or ``evaluation``.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
@abstractmethod
def _setup(self) -> None:
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
self._write_session_metadata_file()
self._can_learn = True
self._can_evaluate = False
@abstractmethod
def _save_checkpoint(self) -> None:
pass
@abstractmethod
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_learn:
_LOGGER.info("Finished learning")
_LOGGER.debug("Writing transactions")
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
@abstractmethod
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_evaluate:
self._update_session_metadata_file()
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self) -> None:
pass
def load(self, path: Union[str, Path]) -> None:
"""Load an agent from file."""
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
# set training config path
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
self._lay_down_config_path: Union[Path, str] = laydown_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# restore the session UUID from the saved metadata
self._uuid = md_dict["uuid"]
# set the session path
self.session_path = Path(path)  # the output path properties rely on Path's "/" operator
"The Session path"
@property
def _saved_agent_path(self) -> Path:
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod
def save(self) -> None:
"""Save the agent."""
pass
@abstractmethod
def export(self) -> None:
"""Export the agent to transportable file format."""
pass
def close(self) -> None:
"""Closes the agent."""
self._env.episode_av_reward_writer.close() # noqa
self._env.transaction_writer.close() # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
# self.close()
title = f"PrimAITE Session {self.timestamp_str} "
subtitle = str(self._training_config)
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
if learning_session:
title += "(Learning)"
path = self.learning_path / csv_file
image_path = self.learning_path / image_file
else:
title += "(Evaluation)"
path = self.evaluation_path / csv_file
image_path = self.evaluation_path / image_file
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {image_path}")
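
The session metadata file in this module is written once at setup with empty counters and then rewritten after each learning or evaluation run. The following is a rough sketch of that write-then-update lifecycle, assuming a simplified dictionary without the `env` section; `write_metadata`, `update_metadata` and `demo` are illustrative helpers, not PrimAITE functions.

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict


def write_metadata(session_path: Path, uuid: str, started: datetime) -> Path:
    """Create session_metadata.json with the end time and counters left empty."""
    metadata: Dict[str, Any] = {
        "uuid": uuid,
        "start_datetime": started.isoformat(),
        "end_datetime": None,
        "learning": {"total_episodes": None, "total_time_steps": None},
        "evaluation": {"total_episodes": None, "total_time_steps": None},
    }
    filepath = session_path / "session_metadata.json"
    with open(filepath, "w") as file:
        json.dump(metadata, file)
    return filepath


def update_metadata(filepath: Path, is_eval: bool, episodes: int, steps: int) -> Dict[str, Any]:
    """Read the file back, fill in the counters for one phase, and rewrite it."""
    with open(filepath, "r") as file:
        metadata = json.load(file)
    metadata["end_datetime"] = datetime.now().isoformat()
    phase = "evaluation" if is_eval else "learning"
    metadata[phase]["total_episodes"] = episodes
    metadata[phase]["total_time_steps"] = steps
    with open(filepath, "w") as file:
        json.dump(metadata, file)
    return metadata


def demo() -> Dict[str, Any]:
    """Run a learning update followed by an evaluation update in a temporary directory."""
    with tempfile.TemporaryDirectory() as tmp:
        filepath = write_metadata(Path(tmp), "example-uuid", datetime(2023, 7, 28, 12, 53, 49))
        update_metadata(filepath, is_eval=False, episodes=10, steps=2560)
        return update_metadata(filepath, is_eval=True, episodes=1, steps=256)
```

Because each update reads the existing file first, the learning counters survive the later evaluation update.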


@@ -0,0 +1,118 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import time
from abc import abstractmethod
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluating deterministic agents.
This class cannot be instantiated directly; subclasses must implement all of its
abstract methods.
"""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[Path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[Path, str]
"""
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs: np.ndarray) -> int:
pass
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
# Reset env and collect initial observation
obs = self._env.reset()
for episode in range(episodes):
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls, path: Union[str, Path] = None) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")
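
The `evaluate` method above runs a fixed number of episodes, stepping a deterministic policy until the step budget is used up or the environment reports `done`, and resets between episodes. A minimal sketch of that loop shape follows; `ToyEnv` and this free-standing `evaluate` are hypothetical stand-ins for the real environment and agent session, and the inter-step delay is omitted.

```python
from typing import Callable, List, Tuple


class ToyEnv:
    """Hypothetical environment whose episodes end after `episode_length` steps."""

    def __init__(self, episode_length: int) -> None:
        self.episode_length = episode_length
        self._step = 0

    def reset(self) -> int:
        self._step = 0
        return 0  # initial observation

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        self._step += 1
        done = self._step >= self.episode_length
        return self._step, float(action), done, {}


def evaluate(env: ToyEnv, episodes: int, time_steps: int, policy: Callable[[int], int]) -> List[int]:
    """Run the policy for `episodes` episodes and return the steps taken in each."""
    steps_per_episode: List[int] = []
    obs = env.reset()
    for _ in range(episodes):
        steps_taken = 0
        for _ in range(time_steps):
            action = policy(obs)
            obs, reward, done, info = env.step(action)
            steps_taken += 1
            if done:
                break
        steps_per_episode.append(steps_taken)
        # Start the next episode from a fresh observation
        obs = env.reset()
    return steps_per_episode


early_finish = evaluate(ToyEnv(episode_length=3), episodes=2, time_steps=10, policy=lambda obs: 1)
budget_limited = evaluate(ToyEnv(episode_length=5), episodes=2, time_steps=4, policy=lambda obs: 1)
```

The two calls show both exit conditions: the first ends each episode when the environment signals `done`, the second when the step budget runs out.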


@@ -0,0 +1,515 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from typing import Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs: np.ndarray) -> int:
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
# full view action using observation space, action
# history and reward feedback
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, IER]:
"""Get blocked green IERs.
:param green_iers: Green IERs to check for being blocked
:type green_iers: Dict[str, IER]
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
:rtype: Dict[str, IER]
"""
blocked_green_iers = {}
for green_ier_id, green_ier in green_iers.items():
source_node_id = green_ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = green_ier.get_dest_node_id()
dest_node_address = nodes[dest_node_id].ip_address
protocol = green_ier.get_protocol() # e.g. 'TCP'
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
def get_matching_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get list of ACL rules which are relevant to an IER.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: The ACL rules relevant to the IER, keyed by their position in the ACL.
:rtype: Dict[int, ACLRule]
"""
source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = ier.get_dest_node_id()
dest_node_address = nodes[dest_node_id].ip_address
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""
Get blocking ACL rules for an IER.
.. warning::
Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked).
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: The matching DENY rules for the IER, keyed by their position in the ACL.
:rtype: Dict[int, ACLRule]
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
blocked_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "DENY":
blocked_rules[rule_key] = rule_value
return blocked_rules
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get all allowing ACL rules for an IER.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: The matching ALLOW rules for the IER, keyed by their position in the ACL.
:rtype: Dict[int, ACLRule]
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "ALLOW":
allowed_rules[rule_key] = rule_value
return allowed_rules
def get_matching_acl_rules(
self,
source_node_id: str,
dest_node_id: str,
protocol: str,
port: str,
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[int, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
:param source_node_id: Source node
:type source_node_id: str
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: str
:param port: Network port
:type port: str
:param acl: Access Control list which will be filtered
:type acl: AccessControlList
:param nodes: The environment's node directory.
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
:param services_list: List of services registered for the environment.
:type services_list: List[str]
:return: Filtered version of 'acl'
:rtype: Dict[int, ACLRule]
"""
if source_node_id != "ANY":
source_node_address = nodes[str(source_node_id)].ip_address
else:
source_node_address = source_node_id
if dest_node_id != "ANY":
dest_node_address = nodes[str(dest_node_id)].ip_address
else:
dest_node_address = dest_node_id
if protocol != "ANY":
protocol = services_list[protocol - 1] # -1 as we don't have to account for ANY in the list of services
# TODO: This should throw an error because protocol is a string
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_allow_acl_rules(
self,
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[int, ACLRule]:
"""List ALLOW rules relating to specified nodes.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
destination nodes
:rtype: Dict[int, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "ALLOW":
allowed_rules[rule_key] = rule_value
return allowed_rules
def get_deny_acl_rules(
self,
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[int, ACLRule]:
"""List DENY rules relating to specified nodes.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
destination nodes
:rtype: Dict[int, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
denied_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "DENY":
denied_rules[rule_key] = rule_value
return denied_rules
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
"""
Calculate a good acl-based action for the blue agent to take.
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
- Which ACL rules already exist. Otherwise:
- The agent would permanently get stuck in a loop of performing the same action over and over.
(the best action is to block something, but it is already blocked and the agent doesn't know this)
- The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
if it doesn't know what rules exist)
- The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example,
in the default config one of the green IERs is blocked by default, but it has no way of knowing this
based on the observation space. Additionally, potentially in the future, once a node state
has been fixed (no longer compromised), it needs a way to know it should re-allow traffic.
An RL agent can learn what the green IERs are on its own, but the rule-based agent cannot easily do this.
There doesn't seem to be much that can be done if an Operating or OS state is compromised.
If a service node becomes compromised there's a decision to make - do we block that service?
Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
Cons: Will block a green IER, decreasing the reward
We decide to block the service.
Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
nodes once a service becomes overwhelmed. However, currently the ACL action space has no way of reversing
an overwhelmed state, so we don't do this.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
# obs = convert_to_old_obs(obs)
r_obs = transform_change_obs_readable(obs)
_, _, _, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
# 1. Check if node is compromised. If so we want to block its outwards services
# a. If it is compromised, check if there's an allow rule we should delete.
# cons: might delete a multi-rule from any source node (ANY -> x)
# b. OPTIONAL (Deny rules not needed): Check whether a Deny rule already exists, so as not to duplicate it
# c. OPTIONAL (no allow rule = blocked): Add a DENY rule
found_action = False
for service_num, service_states in enumerate(s):
for x, service_state in enumerate(service_states):
if service_state == "COMPROMISED":
action_source_id = x + 1 # +1 as 0 is any
action_destination_id = "ANY"
action_protocol = service_num + 1 # +1 as 0 is any
action_port = "ANY"
allow_rules = self.get_allow_acl_rules(
action_source_id,
action_destination_id,
action_protocol,
action_port,
self._env.acl,
self._env.nodes,
self._env.services_list,
)
deny_rules = self.get_deny_acl_rules(
action_source_id,
action_destination_id,
action_protocol,
action_port,
self._env.acl,
self._env.nodes,
self._env.services_list,
)
if len(allow_rules) > 0:
# Check if there's an allow rule we should delete
rule = list(allow_rules.values())[0]
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
elif len(deny_rules) > 0:
# TODO OPTIONAL
# If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
# to create another
# Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
continue
else:
# TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
action_decision = "CREATE"
action_permission = "DENY"
break
if found_action:
break
# 2. If NO node is compromised, or the node has already been blocked, check the green IERs and
# add an Allow rule if the green IER is being blocked.
# a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
# If there's a DENY rule delete it if:
# - There isn't already a deny rule
# - It doesn't allow a compromised node to become operational.
# b. Add an ALLOW rule if:
# - There isn't already an allow rule
# - It doesn't allow a compromised node to become operational
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(service_name_to_check)
# Service state of the source node in the IER
service_state = s[service_id_to_check][node_id_to_check - 1]
if len(allowing_rules) == 0 and service_state != "COMPROMISED":
action_decision = "CREATE"
action_permission = "ALLOW"
action_source_id = int(ier.get_source_node_id())
action_destination_id = int(ier.get_dest_node_id())
action_protocol_name = ier.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = ier.get_port()
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
if found_action:
action = [
action_decision,
action_permission,
action_source_id,
action_destination_id,
action_protocol,
action_port,
]
action = transform_action_acl_enum(action)
action = get_new_action(action, self._env.action_dict)
else:
# If no good/useful action has been found, just perform a nothing action
action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
action = transform_action_acl_enum(action)
action = get_new_action(action, self._env.action_dict)
return action
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
"""
Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge
of previous actions taken and NO reward feedback.
We rely on randomness to select the precise action, as we want to
block all traffic originating from a compromised node, without being
able to tell:
1. Which ACL rules already exist
2. Which actions the agent has already tried.
There is a high probability that the correct rule will not be deleted
before the state becomes overwhelmed.
Currently, a deny rule does not overwrite an allow rule. The allow
rules must be deleted.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, _, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
for service_num, service_states in enumerate(s):
compromised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
if len(compromised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = np.random.choice(compromised_states) + 1 # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = (
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block
# Bad assumption that number of protocols equals number of ports
# AND no rules exist with an ANY port
action_port = np.random.choice(list(range(1, len(s) + 1)))
action = [
action_decision,
action_permission,
action_source_ip,
action_destination_ip,
action_protocol,
action_port,
]
action = transform_action_acl_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# If no good/useful action has been found, just perform a nothing action
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, action_dict)
return nothing_action
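
Note on the ALLOW/DENY filters above: `get_allow_acl_rules` and `get_deny_acl_rules` differ only in the permission string they compare against. The following is a minimal, self-contained sketch of that shared filtering step, not the PrimAITE implementation. The `Rule` dataclass and `filter_by_permission` helper are hypothetical names introduced purely for illustration; the only behaviour taken from the source is that a rule exposes `get_permission()` returning "ALLOW" or "DENY".

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Rule:
    """Hypothetical stand-in for ACLRule; only the permission matters here."""

    permission: str  # "ALLOW" or "DENY"
    source_ip: str
    dest_ip: str

    def get_permission(self) -> str:
        return self.permission


def filter_by_permission(rules: Dict[int, Rule], permission: str) -> Dict[int, Rule]:
    """Keep only the rules whose permission equals the requested one."""
    return {key: rule for key, rule in rules.items() if rule.get_permission() == permission}


# Example rule set, keyed by rule index as in the Dict[int, ACLRule] return type.
matching = {
    0: Rule("ALLOW", "192.168.1.2", "192.168.1.3"),
    1: Rule("DENY", "192.168.1.2", "ANY"),
    2: Rule("ALLOW", "ANY", "192.168.1.4"),
}
allowed = filter_by_permission(matching, "ALLOW")
denied = filter_by_permission(matching, "DENY")
```

Sharing one helper in this way would keep the two public methods as thin wrappers, which is one possible way to avoid the duplicated loop.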


@@ -0,0 +1,125 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs: np.ndarray) -> int:
"""
Calculate a good node-based action for the blue agent to take.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, os, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
# Check in order of most important states (order doesn't currently
# matter, but it probably should)
# First see if any OS states are compromised
for x, os_state in enumerate(os):
if os_state == "COMPROMISED":
action_node_id = x + 1
action_node_property = "OS"
property_action = "PATCHING"
action_service_index = 0 # unused: the service index is not relevant for the OS
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Next, see if any Services are compromised
# We fix the compromised state before the overwhelmed state.
# If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored
for service_num, service in enumerate(s):
for x, service_state in enumerate(service):
if service_state == "COMPROMISED":
action_node_id = x + 1
action_node_property = "SERVICE"
property_action = "PATCHING"
action_service_index = service_num
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Next, See if any services are overwhelmed
# perhaps this should be fixed automatically when the compromised PCs issues are also resolved
# Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs
for service_num, service in enumerate(s):
for x, service_state in enumerate(service):
if service_state == "OVERWHELMED":
action_node_id = x + 1
action_node_property = "SERVICE"
property_action = "PATCHING"
action_service_index = service_num
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Finally, turn on any off nodes
for x, operating_state in enumerate(o):
if operating_state == "OFF":
action_node_id = x + 1
action_node_property = "OPERATING"
property_action = "ON" # Why reset it when we can just turn it on
action_service_index = 0 # unused: the service index is not relevant for the operating state
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# If there are no good actions, just go with an action that won't do any harm
action_node_id = 1
action_node_property = "NONE"
property_action = "ON"
action_service_index = 0
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
return action
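
Note on the priority order above: the node agent checks states in a fixed order and returns the first action it finds. The following is a minimal sketch of that ordering under stated assumptions, not the PrimAITE code. The function name `choose_node_action` is hypothetical, the states are plain strings, and the result is the readable four-element action list before any enum transformation or lookup in the action dictionary.

```python
from typing import List, Union

Action = List[Union[int, str]]


def choose_node_action(
    operating_states: List[str],
    os_states: List[str],
    service_states: List[List[str]],
) -> Action:
    """Return [node_id, node_property, property_action, service_index]."""
    # 1. Patch any compromised operating system first.
    for index, state in enumerate(os_states):
        if state == "COMPROMISED":
            return [index + 1, "OS", "PATCHING", 0]
    # 2 and 3. Patch compromised services, then overwhelmed ones.
    for target_state in ("COMPROMISED", "OVERWHELMED"):
        for service_index, states in enumerate(service_states):
            for index, state in enumerate(states):
                if state == target_state:
                    return [index + 1, "SERVICE", "PATCHING", service_index]
    # 4. Turn on any node that is off.
    for index, state in enumerate(operating_states):
        if state == "OFF":
            return [index + 1, "OPERATING", "ON", 0]
    # 5. Nothing to fix: fall back to the harmless do-nothing action.
    return [1, "NONE", "ON", 0]


example = choose_node_action(
    operating_states=["ON", "OFF"],
    os_states=["GOOD", "GOOD"],
    service_states=[["GOOD", "OVERWHELMED"]],
)
```

In this example the overwhelmed service on node 2 is handled before the powered-off node, because the service checks come earlier in the order.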


@@ -0,0 +1,286 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
import shutil
import zipfile
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
from primaite.environment.primaite_env import Primaite
from primaite.exceptions import RLlibAgentError
_LOGGER: Logger = getLogger(__name__)
# TODO: verify type of env_config
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"],
)
# TODO: verify type hint return type
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config: Dict) -> UnifiedLogger:
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise the RLLib Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB")
:raises ValueError: If the training config contains an unexpected value for agent_identifier (should be `PPO`
or `A2C`)
"""
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.critical(msg)
raise NotImplementedError(msg)
super().__init__(training_config_path, lay_down_config_path)
if self._training_config.session_type == SessionType.EVAL:
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
_LOGGER.critical(msg)
raise RLlibAgentError(msg)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config_class: Union[PPOConfig, A2CConfig]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: Union[PPOConfig, A2CConfig]
self._current_result: dict
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
self._train_agent = None # Required to capture the learning agent to close after eval
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
Updates the ``session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self) -> None:
super()._setup()
register_env("primaite", _env_creator)
self._agent_config = self._agent_config_class()
self._agent_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_train_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
super().learn()
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent = self._agent
else:
self._agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
def _unpack_saved_agent_into_eval(self) -> Path:
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
agent_restore_path = self.evaluation_path / "agent_restore"
if agent_restore_path.exists():
shutil.rmtree(agent_restore_path)
agent_restore_path.mkdir()
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
zip_file.extractall(agent_restore_path)
return agent_restore_path
def _setup_eval(self) -> None:
self._can_learn = False
self._can_evaluate = True
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._setup_eval()
self._env: Primaite = Primaite(
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
)
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action = self._agent.compute_single_action(observation=obs, explore=False)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
# Now we're safe to close the learning agent and write the mean rewards per episode for it
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# Perform a clean-up of the unpacked agent
if (self.evaluation_path / "agent_restore").exists():
shutil.rmtree((self.evaluation_path / "agent_restore"))
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError
@classmethod
def load(cls, path: Union[str, Path]) -> RLlibAgent:
"""Load an agent from file."""
raise NotImplementedError
def save(self, overwrite_existing: bool = True) -> None:
"""Save the agent."""
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved RLlib checkpoint inside the temp directory
checkpoint_dir = next(temp_dir.iterdir())
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError
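
Note on `save` and `_unpack_saved_agent_into_eval` above: together they form a round trip in which a checkpoint directory is zipped with `shutil.make_archive` and later extracted with `zipfile.ZipFile.extractall`. The following is a minimal standard-library sketch of that round trip. The helper names, file names and directory layout are hypothetical, and no Ray call is involved; a text file stands in for the checkpoint contents.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path


def zip_checkpoint(checkpoint_dir: Path, archive_path: Path) -> Path:
    """Zip the contents of checkpoint_dir into archive_path (a .zip file)."""
    # make_archive appends ".zip" itself, so pass the path without its suffix.
    base_name = str(archive_path.with_suffix(""))
    return Path(shutil.make_archive(base_name, "zip", checkpoint_dir))


def unzip_checkpoint(archive_path: Path, restore_dir: Path) -> Path:
    """Extract archive_path into a fresh restore_dir and return it."""
    if restore_dir.exists():
        shutil.rmtree(restore_dir)
    restore_dir.mkdir(parents=True)
    with zipfile.ZipFile(archive_path, "r") as zip_file:
        zip_file.extractall(restore_dir)
    return restore_dir


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Stand-in for the directory an agent save would produce.
    checkpoint = root / "checkpoint_000001"
    checkpoint.mkdir()
    (checkpoint / "policy_state.txt").write_text("weights")

    archive = zip_checkpoint(checkpoint, root / "saved_agent.zip")
    restored = unzip_checkpoint(archive, root / "agent_restore")
    restored_text = (restored / "policy_state.txt").read_text()
```

The sketch uses `Path.with_suffix("")` rather than a string replace of ".zip", since a string replace would also alter any ".zip" appearing earlier in the path.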

src/primaite/agents/sb3.py Normal file

@@ -0,0 +1,206 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
from logging import Logger
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
) -> None:
"""
Initialise the SB3 Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
:raises ValueError: If the training config contains an unexpected value for agent_identifier (should be `PPO`
or `A2C`)
"""
super().__init__(
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_class: Union[PPO, A2C]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
_LOGGER.error(msg)
raise ValueError(msg)
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}"
)
self.is_eval = False
self._setup()
def _setup(self) -> None:
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
legacy_training_config=self.legacy_training_config,
legacy_lay_down_config=self.legacy_lay_down_config,
)
# check if there is a zip file that needs to be loaded
load_file = next(self.session_path.rglob("*.zip"), None)
if not load_file:
# create a new env and agent
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
# load environment values
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
self._env.total_step_count = 0
else:
# carry on from previous learning sessions
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
# load the file
self._agent = self._agent_class.load(load_file, env=self._env)
# set agent values
self._agent.verbose = self.sb3_output_verbose_level
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
super()._setup()
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env._write_av_reward_per_episode() # noqa
self.save()
self._env.close()
super().learn()
# save agent
self.save()
self._plot_av_reward_per_episode(learning_session=True)
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self._env._write_av_reward_per_episode() # noqa
self._env.close()
super().evaluate()
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError
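
Note on `_save_checkpoint` above: both agent classes save a checkpoint only when checkpointing is enabled, at least one episode has completed, and the episode count is a multiple of `checkpoint_every_n_episodes`. The following is a minimal sketch of that condition as a standalone predicate. The function name `should_save_checkpoint` is hypothetical and is not part of PrimAITE.

```python
def should_save_checkpoint(episode_count: int, checkpoint_every_n: int) -> bool:
    """True when a checkpoint is due after episode_count completed episodes."""
    if not checkpoint_every_n:  # 0 (or None) disables checkpointing
        return False
    return episode_count > 0 and episode_count % checkpoint_every_n == 0


# Episodes at which a checkpoint would be written when saving every 5 episodes.
saved_at = [episode for episode in range(0, 11) if should_save_checkpoint(episode, 5)]
```

Episode 0 is excluded explicitly, because 0 is a multiple of every interval and would otherwise trigger a save before any training has happened.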


@@ -0,0 +1,59 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
class RandomAgent(HardCodedAgentSessionABC):
"""
A Random Agent.
Get a completely random action from the action space.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
return self._env.action_space.sample()
class DummyAgent(HardCodedAgentSessionABC):
"""
A Dummy Agent.
All action spaces are set up so that the dummy action is always 0, regardless of the action type used.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
return 0
class DoNothingACLAgent(HardCodedAgentSessionABC):
"""
A do nothing ACL agent.
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
return nothing_action
class DoNothingNodeAgent(HardCodedAgentSessionABC):
"""
A do nothing Node agent.
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
# nothing_action should currently always be 0
return nothing_action
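
The do-nothing agents above build a readable action, encode it to integers, and then look up its key in the environment's action dictionary. A minimal, self-contained sketch of that flow for the node case is given below. The enums are trimmed copies of those in `primaite.common.enums`, `encode_node_action` and `lookup_action_key` mirror `transform_action_node_enum` and `get_new_action`, and `example_action_dict` is a hypothetical stand-in for `self._env.action_dict`, whose real contents are not shown in this diff.

```python
from enum import Enum
from typing import Dict, List, Union


class NodePOLType(Enum):
    NONE = 0
    OPERATING = 1
    OS = 2
    SERVICE = 3
    FILE = 4


class NodeHardwareAction(Enum):
    NONE = 0
    ON = 1
    OFF = 2
    RESET = 3


class NodeSoftwareAction(Enum):
    NONE = 0
    PATCHING = 1


def encode_node_action(action: List[Union[str, int]]) -> List[int]:
    """Encode [node_id, property, verb, service_index] as integers."""
    node_property = NodePOLType[action[1]].value
    if action[1] == "OPERATING":
        verb = NodeHardwareAction[action[2]].value
    elif action[1] in ("OS", "SERVICE"):
        verb = NodeSoftwareAction[action[2]].value
    else:
        # Any other property (including "NONE") carries no verb.
        verb = 0
    return [action[0], node_property, verb, action[3]]


def lookup_action_key(encoded: List[int], action_dict: Dict[int, List[int]]) -> int:
    """Return the key whose value matches the encoded action, else 0."""
    for key, val in action_dict.items():
        if list(val) == list(encoded):
            return key
    return 0


# Hypothetical action dictionary; the real one is built by the environment.
example_action_dict = {0: [1, 0, 0, 0], 1: [1, 1, 2, 0], 2: [1, 3, 1, 0]}

do_nothing = encode_node_action([1, "NONE", "ON", 0])
patch_service = encode_node_action([1, "SERVICE", "PATCHING", 0])
do_nothing_key = lookup_action_key(do_nothing, example_action_dict)
patch_service_key = lookup_action_key(patch_service, example_action_dict)
```

Because the property is `"NONE"`, the `"ON"` verb in the do-nothing action is discarded during encoding, which is why the agent's comment expects the lookup to resolve to key 0.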


@@ -0,0 +1,450 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from typing import Dict, List, Union
import numpy as np
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
LinkStatus,
NodeHardwareAction,
NodePOLType,
NodeSoftwareAction,
SoftwareState,
)
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
"""Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action: List[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
return new_action
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, 0 means any, otherwise it's just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == 0:
new_action[n + 2] = "ANY"
return new_action
def is_valid_node_action(action: List[int]) -> bool:
"""
Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
# print("node property", node_property, "\nnode action", node_action)
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action: List[int]) -> bool:
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect.
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by default. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action: List[int]) -> bool:
"""
Stricter version of is_valid_acl_action that also rejects actions whose protocol or port is ANY.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
if is_valid_acl_action(action) is False:
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
os_states = [SoftwareState(i).name for i in obs[:, 2]]
new_obs = [ids, operating_states, os_states]
for service in range(4, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
new_obs.append(service_states)
return new_obs
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform observation to readable format.
example
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
# Convert list of tuples to list of lists
new_obs = [list(i) for i in new_obs]
return new_obs
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
"""Convert original gym Box observation space to new multiDiscrete observation space.
:param obs: observation in the 'old' (NodeLinkTable) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:return: reformatted observation
:rtype: np.ndarray
"""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
"""Convert to old observation.
Links filled with 0's as no information is included in new observation space.
example:
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
new_obs = array([[ 1, 1, 1, 1],
[ 2, 1, 1, 1],
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
:param obs: observation in the 'new' (MultiDiscrete) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:param num_links: number of links in the network, defaults to 10
:type num_links: int, optional
:param num_services: number of services on the network, defaults to 1
:type num_services: int, optional
:return: 2-d BOX observation space, in the same format as NodeLinkTable
:rtype: np.ndarray
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
# Add empty links back and add node ID back
s = np.zeros(
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
dtype=np.int64,
)
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
s[:num_nodes, 1:] = reshaped_nodes # put values back in
new_obs = s
# Add links back in
links = obs[-num_links:]
# Links will be added to the last protocol/service slot but they are not specific to that service
new_obs[num_nodes:, -1] = links
return new_obs
def describe_obs_change(
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
) -> str:
"""Build a string describing the difference between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 2 set to GOOD'
:param obs1: First observation
:type obs1: np.ndarray
:param obs2: Second observation
:type obs2: np.ndarray
:param num_nodes: How many nodes are in the network laydown, defaults to 10
:type num_nodes: int, optional
:param num_links: How many links are in the network laydown, defaults to 10
:type num_links: int, optional
:param num_services: How many services are configured for this scenario, defaults to 1
:type num_services: int, optional
:return: A multi-line string with a human-readable description of the difference.
:rtype: str
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
list_of_changes = []
for n, row in enumerate(obs1 - obs2):
if row.any() != 0:
relevant_changes = np.where(row != 0, obs2[n], -1)
relevant_changes[0] = obs2[n, 0] # ID is always relevant
is_link = relevant_changes[0] > num_nodes
desc = _describe_obs_change_helper(relevant_changes, is_link)
list_of_changes.append(desc)
change_string = "\n ".join(list_of_changes)
if len(list_of_changes) > 0:
change_string = "\n " + change_string
return change_string
def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str:
"""
Helper function to describe what has changed.
example:
[ 1 -1 -1 -1 1] -> "ID 1: SERVICE 1 changed to GOOD."
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 changed to GOOD.'
:param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one
row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new
status where it has changed.
:type obs_change: List[int]
:param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node.
:type is_link: bool
:return: A human-readable description of the difference between the two observation rows.
:rtype: str
"""
# Indexes where a change has occurred, not including 0th index
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
# Node pol types, Indexes >= 3 are service nodes
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
# Account for hardware states, software states and links
states = [
LinkStatus(obs_change[i]).name
if is_link
else HardwareState(obs_change[i]).name
if i == 1
else SoftwareState(obs_change[i]).name
for i in index_changed
]
if not is_link:
desc = f"ID {obs_change[0]}:"
for node_pol_type, state in list(zip(NodePOLTypes, states)):
desc = desc + " " + node_pol_type + " changed to " + state + "."
else:
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
return desc
def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]:
"""Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
:param action: Action in 'readable' format
:type action: List[Union[str,int]]
:return: Action with verbs encoded as ints
:rtype: List[int]
"""
action_node_id = action[0]
action_node_property = NodePOLType[action[1]].value
if action[1] == "OPERATING":
property_action = NodeHardwareAction[action[2]].value
elif action[1] == "OS" or action[1] == "SERVICE":
property_action = NodeSoftwareAction[action[2]].value
else:
property_action = 0
action_service_index = action[3]
new_action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
return new_action
def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray:
"""
Convert acl action from readable str format, to enumerated format.
:param action: ACL-based action expressed as a list of human-readable ints and strings
:type action: List[Union[int,str]]
:return: The same action but encoded to contain only integers.
:rtype: np.ndarray
"""
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
action_permissions = {"DENY": 0, "ALLOW": 1}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, ANY has value 0, otherwise it's just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == "ANY":
new_action[n + 2] = 0
new_action = np.array(new_action)
return new_action
def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str:
"""Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be obtained from env.nodes)
:param ip: The IP address of the node whose ID is required
:type ip: str
:param node_dict: The environment's node registry dictionary
:type node_dict: Dict[str,NodeUnion]
:return: The key from the registry dict that corresponds to the node with the IP address provided by `ip`
:rtype: str
"""
for node_key, node_value in node_dict.items():
node_ip = node_value.ip_address
if node_ip == ip:
return node_key
def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int:
"""
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
Old_action can be either node or acl action type
:param old_action: Action expressed as a list of choices, eg. [1,1,1,0]
:type old_action: np.ndarray
:param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions.
:type action_dict: Dict[int,List]
:return: Action key corresponding to the input `old_action`
:rtype: int
"""
for key, val in action_dict.items():
if list(val) == list(old_action):
return key
# Not all possible actions are included in the dict, only valid actions are.
# If the action is not in the dict, it's an invalid action, so return 0.
return 0
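
The ACL helpers above translate between integer and readable actions and then filter out actions that cannot have an effect. A minimal sketch of that round trip follows, assuming the same decision and permission tables as the source. It returns plain lists rather than `np.ndarray` so that it runs without NumPy, and the names `acl_to_readable`, `acl_to_enum` and `acl_has_effect` are illustrative stand-ins for `transform_action_acl_readable`, `transform_action_acl_enum` and `is_valid_acl_action`.

```python
from typing import List, Union

ACL_DECISIONS = {0: "NONE", 1: "CREATE", 2: "DELETE"}
ACL_PERMISSIONS = {0: "DENY", 1: "ALLOW"}

ReadableAction = List[Union[str, int]]


def acl_to_readable(action: List[int]) -> ReadableAction:
    """Decode an integer ACL action; 0 in the last four slots means ANY."""
    readable: ReadableAction = [ACL_DECISIONS[action[0]], ACL_PERMISSIONS[action[1]]]
    readable += ["ANY" if value == 0 else value for value in action[2:6]]
    return readable


def acl_to_enum(action: ReadableAction) -> List[int]:
    """Encode a readable ACL action back to integers (inverse of the above)."""
    decisions = {name: code for code, name in ACL_DECISIONS.items()}
    permissions = {name: code for code, name in ACL_PERMISSIONS.items()}
    encoded = [decisions[action[0]], permissions[action[1]]]
    encoded += [0 if value == "ANY" else value for value in action[2:6]]
    return encoded


def acl_has_effect(action: List[int]) -> bool:
    """Apply the same three rejection rules as is_valid_acl_action."""
    decision, permission, source, destination = acl_to_readable(action)[:4]
    if decision == "NONE":
        return False
    if source == destination and source != "ANY":
        # A rule from a node to itself.
        return False
    if permission == "DENY":
        # Blocked is the default, so only ALLOW rules are useful.
        return False
    return True


encoded = [1, 1, 2, 5, 0, 1]
readable = acl_to_readable(encoded)
round_trip = acl_to_enum(readable)
```

Note that an ANY-to-ANY rule passes the self-rule check, matching the source condition that only rejects identical source and destination when neither is `"ANY"`.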

213
src/primaite/cli.py Normal file

@@ -0,0 +1,213 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Provides a CLI using Typer as an entry point."""
import logging
import os
from enum import Enum
from typing import Optional
import typer
import yaml
from typing_extensions import Annotated
from primaite import PRIMAITE_PATHS
from primaite.data_viz import PlotlyTemplate
app = typer.Typer()
@app.command()
def build_dirs() -> None:
"""Build the PrimAITE app directories."""
from primaite import PRIMAITE_PATHS
PRIMAITE_PATHS.mkdirs()
@app.command()
def reset_notebooks(overwrite: bool = True) -> None:
"""
Force a reset of the demo notebooks in the user's notebooks directory.
:param overwrite: If True, will overwrite existing demo notebooks.
"""
from primaite.setup import reset_demo_notebooks
reset_demo_notebooks.run(overwrite)
@app.command()
def logs(last_n: Annotated[int, typer.Option("-n")] = 10) -> None:
"""
Print the PrimAITE log file.
:param last_n: The number of lines to print. Default value is 10.
"""
import re
from primaite import PRIMAITE_PATHS
if os.path.isfile(PRIMAITE_PATHS.app_log_file_path):
with open(PRIMAITE_PATHS.app_log_file_path) as file:
lines = file.readlines()
for line in lines[-last_n:]:
print(re.sub(r"\n*", "", line))
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
@app.command()
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
"""
View or set the PrimAITE Log Level.
To View, simply call: primaite log-level
To set, call: primaite log-level <desired log level>
For example, to set the level to debug, call: primaite log-level DEBUG
"""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if level:
primaite_config["logging"]["log_level"] = level.value
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE Log Level: {level.value}")
else:
level = primaite_config["logging"]["log_level"]
print(f"PrimAITE Log Level: {level}")
@app.command()
def notebooks() -> None:
"""Start Jupyter Lab in the user's PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
start_jupyter_session()
@app.command()
def version() -> None:
"""Get the installed PrimAITE version number."""
import primaite
print(primaite.__version__)
@app.command()
def clean_up() -> None:
"""Cleans up leftover files from previous version installations."""
from primaite.setup import old_installation_clean_up
old_installation_clean_up.run()
@app.command()
def setup(overwrite_existing: bool = True) -> None:
"""
Perform the PrimAITE first-time setup.
WARNING: All user-data will be lost.
"""
from primaite import getLogger
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
_LOGGER = getLogger(__name__)
_LOGGER.info("Performing the PrimAITE first-time setup...")
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building the PrimAITE app directories...")
PRIMAITE_PATHS.mkdirs()
_LOGGER.info("Rebuilding the demo notebooks...")
reset_demo_notebooks.run(overwrite_existing=overwrite_existing)
_LOGGER.info("Rebuilding the example configs...")
reset_example_configs.run(overwrite_existing=overwrite_existing)
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
old_installation_clean_up.run()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
tc: Optional[str] = None,
ldc: Optional[str] = None,
load: Optional[str] = None,
legacy_tc: bool = False,
legacy_ldc: bool = False,
) -> None:
"""
Run a PrimAITE session.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, a new session is started
from the training config and lay down config (or their defaults). If a session directory is passed
together with a training config and lay down config, PrimAITE loads the session and ignores both
configs.
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
legacy_ldc: If the lay down config file is a legacy file from PrimAITE < 2.0.
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
if load is not None:
# run a loaded session
run(session_path=load)
else:
# start a new session using tc and ldc
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
run(
training_config_path=tc,
lay_down_config_path=ldc,
legacy_training_config=legacy_tc,
legacy_lay_down_config=legacy_ldc,
)
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
"""
View or set the plotly template for Session plots.
To View, simply call: primaite plotly-template
To set, call: primaite plotly-template <desired template>
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
"""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if template:
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE plotly template: {template.value}")
else:
template = primaite_config["session"]["outputs"]["plots"]["template"]
print(f"PrimAITE plotly template: {template}")
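
The `session` command's docstring describes a precedence rule: a session directory passed via `load` wins over any config paths, and missing config paths fall back to defaults. A rough sketch of that decision as a pure function is shown below. `resolve_session_args`, `DEFAULT_TC` and `DEFAULT_LDC` are hypothetical names introduced for illustration; the real command obtains its defaults from `main_training_config_path()` and `dos_very_basic_config_path()` and passes the result straight to `run`.

```python
from typing import Dict, Optional, Union

# Hypothetical placeholder defaults; not the real PrimAITE paths.
DEFAULT_TC = "training_config_main.yaml"
DEFAULT_LDC = "lay_down_config_default.yaml"


def resolve_session_args(
    tc: Optional[str] = None,
    ldc: Optional[str] = None,
    load: Optional[str] = None,
    legacy_tc: bool = False,
    legacy_ldc: bool = False,
) -> Dict[str, Union[str, bool]]:
    """Return the keyword arguments that would be handed to run()."""
    if load is not None:
        # A previous session takes priority; tc and ldc are ignored.
        return {"session_path": load}
    return {
        "training_config_path": tc or DEFAULT_TC,
        "lay_down_config_path": ldc or DEFAULT_LDC,
        "legacy_training_config": legacy_tc,
        "legacy_lay_down_config": legacy_ldc,
    }


loaded = resolve_session_args(tc="my_tc.yaml", load="sessions/previous")
fresh = resolve_session_args(tc="my_tc.yaml", legacy_tc=True)
```

Keeping the decision in a function that only returns a dictionary makes the precedence easy to check without starting a real session.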


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Objects which are shared between many PrimAITE modules."""


@@ -0,0 +1,8 @@
from typing import Union
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""


@@ -0,0 +1,208 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Enumerations for APE."""
from enum import Enum, IntEnum
class NodeType(Enum):
"""Node type enumeration."""
CCTV = 1
SWITCH = 2
COMPUTER = 3
LINK = 4
MONITOR = 5
PRINTER = 6
LOP = 7
RTU = 8
ACTUATOR = 9
SERVER = 10
class Priority(Enum):
"""Node priority enumeration."""
P1 = 1
P2 = 2
P3 = 3
P4 = 4
P5 = 5
class HardwareState(Enum):
"""Node hardware state enumeration."""
NONE = 0
ON = 1
OFF = 2
RESETTING = 3
SHUTTING_DOWN = 4
BOOTING = 5
class SoftwareState(Enum):
"""Software or Service state enumeration."""
NONE = 0
GOOD = 1
PATCHING = 2
COMPROMISED = 3
OVERWHELMED = 4
class NodePOLType(Enum):
"""Node Pattern of Life type enumeration."""
NONE = 0
OPERATING = 1
OS = 2
SERVICE = 3
FILE = 4
class NodePOLInitiator(Enum):
"""Node Pattern of Life initiator enumeration."""
DIRECT = 1
IER = 2
SERVICE = 3
class Protocol(Enum):
"""Service protocol enumeration."""
LDAP = 0
FTP = 1
HTTPS = 2
SMTP = 3
RTP = 4
IPP = 5
TCP = 6
NONE = 7
class SessionType(Enum):
"""The type of PrimAITE Session to be run."""
TRAIN = 1
"Train an agent"
EVAL = 2
"Evaluate an agent"
TRAIN_EVAL = 3
"Train then evaluate an agent"
class AgentFramework(Enum):
"""The agent algorithm framework/package."""
CUSTOM = 0
"Custom Agent"
SB3 = 1
"Stable Baselines3"
RLLIB = 2
"Ray RLlib"
class DeepLearningFramework(Enum):
"""The deep learning framework."""
TF = "tf"
"Tensorflow"
TF2 = "tf2"
"Tensorflow 2.x"
TORCH = "torch"
"PyTorch"
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"The Hardcoded agents"
DO_NOTHING = 4
"The DoNothing agents"
RANDOM = 5
"The RandomAgent"
DUMMY = 6
"The DummyAgent"
class HardCodedAgentView(Enum):
"""The view the deterministic hard-coded agent has of the environment."""
BASIC = 1
"The current observation space only"
FULL = 2
"Full environment view with actions taken and reward feedback"
class ActionType(Enum):
"""Action type enumeration."""
NODE = 0
ACL = 1
ANY = 2
# TODO: this is not used anymore, write a ticket to delete it.
class ObservationType(Enum):
"""Observation type enumeration."""
BOX = 0
MULTIDISCRETE = 1
class FileSystemState(Enum):
"""File System State."""
GOOD = 1
CORRUPT = 2
DESTROYED = 3
REPAIRING = 4
RESTORING = 5
class NodeHardwareAction(Enum):
"""Node hardware action."""
NONE = 0
ON = 1
OFF = 2
RESET = 3
class NodeSoftwareAction(Enum):
"""Node software action."""
NONE = 0
PATCHING = 1
class LinkStatus(Enum):
"""Link traffic status."""
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
OVERLOAD = 4
class SB3OutputVerboseLevel(IntEnum):
"""The Stable Baselines3 learn/eval output verbosity level."""
NONE = 0
INFO = 1
DEBUG = 2
class RulePermissionType(Enum):
"""Any firewall rule type."""
NONE = 0
DENY = 1
ALLOW = 2
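
Most enums in this module derive from `Enum`, while `SB3OutputVerboseLevel` derives from `IntEnum`. The sketch below, using trimmed copies of two of the enums above, illustrates what that difference means in practice and how the value and name lookups used throughout `primaite.agents.utils` behave. Only standard-library behaviour is relied on; `safe_link_status` is a hypothetical helper that does not appear in the source.

```python
from enum import Enum, IntEnum


class LinkStatus(Enum):
    NONE = 0
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    OVERLOAD = 4


class SB3OutputVerboseLevel(IntEnum):
    NONE = 0
    INFO = 1
    DEBUG = 2


# Lookup by value and lookup by name return the same member object.
by_value = LinkStatus(3)
by_name = LinkStatus["HIGH"]

# Plain Enum members never compare equal to raw integers...
plain_equals_int = LinkStatus.LOW == 1
# ...whereas IntEnum members behave like ints, so they can be passed
# directly where an integer verbosity level is expected.
int_enum_equals_int = SB3OutputVerboseLevel.INFO == 1


def safe_link_status(value: int) -> LinkStatus:
    """Hypothetical helper: fall back to NONE for values outside the enum."""
    try:
        return LinkStatus(value)
    except ValueError:
        return LinkStatus.NONE
```

Calling an `Enum` class with a value it does not define raises `ValueError`, which is why callers that decode raw observation integers need either a guard or a fallback of this kind.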


@@ -0,0 +1,47 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The protocol class."""
class Protocol(object):
"""Protocol class."""
def __init__(self, _name: str) -> None:
"""
Initialise a protocol.
:param _name: The name of the protocol
:type _name: str
"""
self.name: str = _name
self.load: int = 0 # bps
def get_name(self) -> str:
"""
Gets the protocol name.
Returns:
The protocol name
"""
return self.name
def get_load(self) -> int:
"""
Gets the protocol load.
Returns:
The protocol load (bps)
"""
return self.load
def add_load(self, _load: int) -> None:
"""
Adds load to the protocol.
Args:
_load: The load to add
"""
self.load += _load
def clear_load(self) -> None:
"""Clears the load on this protocol."""
self.load = 0


@@ -0,0 +1,28 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Service class."""
from primaite.common.enums import SoftwareState
class Service(object):
"""Service class."""
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
"""
Initialise a service.
:param name: The service name.
:param port: The service port.
:param software_state: The service SoftwareState.
"""
self.name: str = name
self.port: str = port
self.software_state: SoftwareState = software_state
self.patching_count: int = 0
def reduce_patching_count(self) -> None:
"""Reduces the patching count for the service."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self.software_state = SoftwareState.GOOD
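
`reduce_patching_count` counts a service down and flips it back to `GOOD` once the counter reaches zero. The sketch below re-declares the class with its indentation restored and steps through a countdown. The starting value of 3 is an assumption for illustration only; how and where `patching_count` gets set is not shown in this diff.

```python
from enum import Enum
from typing import List, Tuple


class SoftwareState(Enum):
    NONE = 0
    GOOD = 1
    PATCHING = 2
    COMPROMISED = 3
    OVERWHELMED = 4


class Service:
    """Copy of the Service class above, with indentation restored."""

    def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
        self.name: str = name
        self.port: str = port
        self.software_state: SoftwareState = software_state
        self.patching_count: int = 0

    def reduce_patching_count(self) -> None:
        self.patching_count -= 1
        if self.patching_count <= 0:
            # Clamp at zero and mark the patch as finished.
            self.patching_count = 0
            self.software_state = SoftwareState.GOOD


service = Service("TCP", "80", SoftwareState.PATCHING)
service.patching_count = 3  # assumed patch duration in steps

history: List[Tuple[int, str]] = []
for _ in range(3):
    service.reduce_patching_count()
    history.append((service.patching_count, service.software_state.name))
```

One consequence of the `<= 0` check is that calling the method on a service whose counter is already zero sets its state to `GOOD` regardless of its previous state, so callers would presumably only invoke it while the service is patching.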


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Configuration parameters for running experiments."""


@@ -0,0 +1,166 @@
- item_type: PORTS
ports_list:
- port: '80'
- item_type: SERVICES
service_list:
- name: TCP
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: PC2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.4
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '4'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.5
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '5'
name: SWITCH2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.6
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '6'
name: SWITCH3
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.7
software_state: GOOD
file_system_state: GOOD
- item_type: LINK
id: '7'
name: link1
bandwidth: 1000000000
source: '1'
destination: '4'
- item_type: LINK
id: '8'
name: link2
bandwidth: 1000000000
source: '4'
destination: '2'
- item_type: LINK
id: '9'
name: link3
bandwidth: 1000000000
source: '2'
destination: '5'
- item_type: LINK
id: '10'
name: link4
bandwidth: 1000000000
source: '2'
destination: '6'
- item_type: LINK
id: '11'
name: link5
bandwidth: 1000000000
source: '5'
destination: '3'
- item_type: LINK
id: '12'
name: link6
bandwidth: 1000000000
source: '6'
destination: '3'
- item_type: GREEN_IER
id: '13'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '3'
destination: '2'
mission_criticality: 5
- item_type: RED_POL
id: '14'
start_step: 50
end_step: 50
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '15'
start_step: 60
end_step: 100
load: 1000000
protocol: TCP
port: '80'
source: '1'
destination: '2'
mission_criticality: 0
- item_type: RED_POL
id: '16'
start_step: 80
end_step: 80
targetNodeId: '2'
initiator: IER
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: ACL_RULE
id: '17'
permission: ALLOW
source: ANY
destination: ANY
protocol: ANY
port: ANY
position: 0
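
A lay down config such as the one above is a flat list of items distinguished by `item_type`. As a rough illustration of how the `LINK` items describe the topology, the sketch below copies the six links from this config as Python dictionaries (the shape a YAML loader would be expected to produce for a list of mappings), builds an adjacency map, and finds a shortest route between two nodes. Treating links as bidirectional is an assumption, and `build_adjacency` and `shortest_path` are hypothetical helpers that are not part of PrimAITE.

```python
from collections import deque
from typing import Dict, List, Optional

# The LINK items from the config above, reduced to the fields used here.
items = [
    {"item_type": "LINK", "id": "7", "source": "1", "destination": "4"},
    {"item_type": "LINK", "id": "8", "source": "4", "destination": "2"},
    {"item_type": "LINK", "id": "9", "source": "2", "destination": "5"},
    {"item_type": "LINK", "id": "10", "source": "2", "destination": "6"},
    {"item_type": "LINK", "id": "11", "source": "5", "destination": "3"},
    {"item_type": "LINK", "id": "12", "source": "6", "destination": "3"},
]


def build_adjacency(config_items: List[Dict[str, str]]) -> Dict[str, List[str]]:
    """Map each node ID to its neighbours, assuming bidirectional links."""
    adjacency: Dict[str, List[str]] = {}
    for item in config_items:
        if item["item_type"] != "LINK":
            continue
        adjacency.setdefault(item["source"], []).append(item["destination"])
        adjacency.setdefault(item["destination"], []).append(item["source"])
    return adjacency


def shortest_path(adjacency: Dict[str, List[str]], start: str, goal: str) -> Optional[List[str]]:
    """Breadth-first search; returns the node IDs along one shortest route."""
    queue = deque([[start]])
    visited = {start}
    while queue:
        path = queue.popleft()
        if path[-1] == goal:
            return path
        for neighbour in adjacency.get(path[-1], []):
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(path + [neighbour])
    return None


adjacency = build_adjacency(items)
green_route = shortest_path(adjacency, "3", "2")  # GREEN_IER: PC2 to SERVER
red_route = shortest_path(adjacency, "1", "2")  # RED_IER: PC1 to SERVER
```

In this topology PC2 (node 3) reaches the server through either SWITCH2 or SWITCH3, while PC1 (node 1) has a single route through SWITCH1.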


@@ -0,0 +1,366 @@
- item_type: PORTS
ports_list:
- port: '80'
- item_type: SERVICES
service_list:
- name: TCP
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.11
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '2'
name: PC2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: PC3
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.13
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '4'
name: PC4
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.20.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '5'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '6'
name: IDS
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.4
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '7'
name: SWITCH2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '8'
name: LOP1
node_class: SERVICE
node_type: LOP
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '9'
name: SERVER1
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '10'
name: SERVER2
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.20.15
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: LINK
id: '11'
name: link1
bandwidth: 1000000000
source: '1'
destination: '5'
- item_type: LINK
id: '12'
name: link2
bandwidth: 1000000000
source: '2'
destination: '5'
- item_type: LINK
id: '13'
name: link3
bandwidth: 1000000000
source: '3'
destination: '5'
- item_type: LINK
id: '14'
name: link4
bandwidth: 1000000000
source: '4'
destination: '5'
- item_type: LINK
id: '15'
name: link5
bandwidth: 1000000000
source: '5'
destination: '6'
- item_type: LINK
id: '16'
name: link6
bandwidth: 1000000000
source: '5'
destination: '8'
- item_type: LINK
id: '17'
name: link7
bandwidth: 1000000000
source: '6'
destination: '7'
- item_type: LINK
id: '18'
name: link8
bandwidth: 1000000000
source: '8'
destination: '7'
- item_type: LINK
id: '19'
name: link9
bandwidth: 1000000000
source: '7'
destination: '9'
- item_type: LINK
id: '20'
name: link10
bandwidth: 1000000000
source: '7'
destination: '10'
- item_type: GREEN_IER
id: '21'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '1'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '22'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '2'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '23'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '9'
destination: '3'
mission_criticality: 5
- item_type: GREEN_IER
id: '24'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '4'
destination: '10'
mission_criticality: 2
- item_type: ACL_RULE
id: '25'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.10.14
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '26'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.10.14
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '27'
permission: ALLOW
source: 192.168.10.13
destination: 192.168.10.14
protocol: TCP
port: 80
position: 2
- item_type: ACL_RULE
id: '28'
permission: ALLOW
source: 192.168.20.14
destination: 192.168.20.15
protocol: TCP
port: 80
position: 3
- item_type: ACL_RULE
id: '29'
permission: ALLOW
source: 192.168.10.14
destination: 192.168.10.13
protocol: TCP
port: 80
position: 4
- item_type: ACL_RULE
id: '30'
permission: DENY
source: 192.168.10.11
destination: 192.168.20.15
protocol: TCP
port: 80
position: 5
- item_type: ACL_RULE
id: '31'
permission: DENY
source: 192.168.10.12
destination: 192.168.20.15
protocol: TCP
port: 80
position: 6
- item_type: ACL_RULE
id: '32'
permission: DENY
source: 192.168.10.13
destination: 192.168.20.15
protocol: TCP
port: 80
position: 7
- item_type: ACL_RULE
id: '33'
permission: DENY
source: 192.168.20.14
destination: 192.168.10.14
protocol: TCP
port: 80
position: 8
- item_type: RED_POL
id: '34'
start_step: 20
end_step: 20
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_POL
id: '35'
start_step: 20
end_step: 20
targetNodeId: '2'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '36'
start_step: 30
end_step: 128
load: 440000000
protocol: TCP
port: '80'
source: '1'
destination: '9'
mission_criticality: 0
- item_type: RED_IER
id: '37'
start_step: 30
end_step: 128
load: 440000000
protocol: TCP
port: '80'
source: '2'
destination: '9'
mission_criticality: 0
- item_type: RED_POL
id: '38'
start_step: 30
end_step: 30
targetNodeId: '9'
initiator: IER
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA


@@ -0,0 +1,164 @@
- item_type: PORTS
ports_list:
- port: '80'
- item_type: SERVICES
service_list:
- name: TCP
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '2'
name: PC2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '4'
name: SERVER1
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.4
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: LINK
id: '5'
name: link1
bandwidth: 1000000000
source: '1'
destination: '3'
- item_type: LINK
id: '6'
name: link2
bandwidth: 1000000000
source: '2'
destination: '3'
- item_type: LINK
id: '7'
name: link3
bandwidth: 1000000000
source: '3'
destination: '4'
- item_type: GREEN_IER
id: '8'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '1'
destination: '4'
mission_criticality: 1
- item_type: GREEN_IER
id: '9'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '2'
destination: '4'
mission_criticality: 1
- item_type: GREEN_IER
id: '10'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '4'
destination: '2'
mission_criticality: 5
- item_type: ACL_RULE
id: '11'
permission: ALLOW
source: 192.168.1.2
destination: 192.168.1.4
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '12'
permission: ALLOW
source: 192.168.1.3
destination: 192.168.1.4
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '13'
permission: ALLOW
source: 192.168.1.4
destination: 192.168.1.3
protocol: TCP
port: 80
position: 2
- item_type: RED_POL
id: '14'
start_step: 20
end_step: 20
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '15'
start_step: 30
end_step: 256
load: 10000000
protocol: TCP
port: '80'
source: '1'
destination: '4'
mission_criticality: 0
- item_type: RED_POL
id: '16'
start_step: 40
end_step: 40
targetNodeId: '4'
initiator: IER
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
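
In the lay down above, the `source` and `destination` values of each LINK item appear to refer to the `node_id` values of NODE items. As a minimal sketch (not part of PrimAITE; the helper name `find_dangling_links` and the trimmed item list are hypothetical), a loaded lay down could be checked for links that point at undeclared nodes:

```python
from typing import Any, Dict, List

# Hypothetical, trimmed-down copy of the lay down items above, written as the
# list of dicts that a YAML loader would be expected to return.
lay_down: List[Dict[str, Any]] = [
    {"item_type": "NODE", "node_id": "1", "name": "PC1"},
    {"item_type": "NODE", "node_id": "3", "name": "SWITCH1"},
    {"item_type": "NODE", "node_id": "4", "name": "SERVER1"},
    {"item_type": "LINK", "id": "5", "name": "link1", "source": "1", "destination": "3"},
    {"item_type": "LINK", "id": "7", "name": "link3", "source": "3", "destination": "4"},
]


def find_dangling_links(items: List[Dict[str, Any]]) -> List[str]:
    """Return the names of LINK items whose endpoints are not declared NODE ids."""
    node_ids = {item["node_id"] for item in items if item.get("item_type") == "NODE"}
    return [
        item["name"]
        for item in items
        if item.get("item_type") == "LINK"
        and not {item["source"], item["destination"]} <= node_ids
    ]


# All endpoints in the trimmed example are declared nodes.
dangling = find_dangling_links(lay_down)

# Adding a link to an undeclared node id ("42") is reported by name.
bad_link = {"item_type": "LINK", "id": "99", "name": "bad", "source": "1", "destination": "42"}
broken = find_dangling_links(lay_down + [bad_link])
```

The check is limited to LINK items on purpose: ACL_RULE items use IP addresses in the same `source` and `destination` fields, so they would need a different lookup.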


@@ -0,0 +1,546 @@
- item_type: PORTS
ports_list:
- port: '80'
- port: '1433'
- port: '53'
- item_type: SERVICES
service_list:
- name: TCP
- name: TCP_SQL
- name: UDP
- item_type: NODE
node_id: '1'
name: CLIENT_1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.11
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '2'
name: CLIENT_2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: SWITCH_1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.10.1
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '4'
name: SECURITY_SUITE
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.10
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '5'
name: MANAGEMENT_CONSOLE
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '6'
name: SWITCH_2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.2.1
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '7'
name: WEB_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.10
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: TCP_SQL
port: '1433'
state: GOOD
- item_type: NODE
node_id: '8'
name: DATABASE_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: TCP_SQL
port: '1433'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '9'
name: BACKUP_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.16
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: LINK
id: '10'
name: LINK_1
bandwidth: 1000000000
source: '1'
destination: '3'
- item_type: LINK
id: '11'
name: LINK_2
bandwidth: 1000000000
source: '2'
destination: '3'
- item_type: LINK
id: '12'
name: LINK_3
bandwidth: 1000000000
source: '3'
destination: '4'
- item_type: LINK
id: '13'
name: LINK_4
bandwidth: 1000000000
source: '3'
destination: '5'
- item_type: LINK
id: '14'
name: LINK_5
bandwidth: 1000000000
source: '4'
destination: '6'
- item_type: LINK
id: '15'
name: LINK_6
bandwidth: 1000000000
source: '5'
destination: '6'
- item_type: LINK
id: '16'
name: LINK_7
bandwidth: 1000000000
source: '6'
destination: '7'
- item_type: LINK
id: '17'
name: LINK_8
bandwidth: 1000000000
source: '6'
destination: '8'
- item_type: LINK
id: '18'
name: LINK_9
bandwidth: 1000000000
source: '6'
destination: '9'
- item_type: GREEN_IER
id: '19'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '1'
destination: '7'
mission_criticality: 5
- item_type: GREEN_IER
id: '20'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '7'
destination: '1'
mission_criticality: 5
- item_type: GREEN_IER
id: '21'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '2'
destination: '7'
mission_criticality: 5
- item_type: GREEN_IER
id: '22'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '7'
destination: '2'
mission_criticality: 5
- item_type: GREEN_IER
id: '23'
start_step: 1
end_step: 256
load: 5000
protocol: TCP_SQL
port: '1433'
source: '7'
destination: '8'
mission_criticality: 5
- item_type: GREEN_IER
id: '24'
start_step: 1
end_step: 256
load: 100000
protocol: TCP_SQL
port: '1433'
source: '8'
destination: '7'
mission_criticality: 5
- item_type: GREEN_IER
id: '25'
start_step: 1
end_step: 256
load: 50000
protocol: TCP
port: '80'
source: '1'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '26'
start_step: 1
end_step: 256
load: 50000
protocol: TCP
port: '80'
source: '2'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '27'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '7'
mission_criticality: 1
- item_type: GREEN_IER
id: '28'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '7'
destination: '5'
mission_criticality: 1
- item_type: GREEN_IER
id: '29'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '8'
mission_criticality: 1
- item_type: GREEN_IER
id: '30'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '8'
destination: '5'
mission_criticality: 1
- item_type: GREEN_IER
id: '31'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '9'
mission_criticality: 1
- item_type: GREEN_IER
id: '32'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '9'
destination: '5'
mission_criticality: 1
- item_type: ACL_RULE
id: '33'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 0
- item_type: ACL_RULE
id: '34'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 1
- item_type: ACL_RULE
id: '35'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 2
- item_type: ACL_RULE
id: '36'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 3
- item_type: ACL_RULE
id: '37'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.10.11
protocol: ANY
port: ANY
position: 4
- item_type: ACL_RULE
id: '38'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.10.12
protocol: ANY
port: ANY
position: 5
- item_type: ACL_RULE
id: '39'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 6
- item_type: ACL_RULE
id: '40'
permission: ALLOW
source: 192.168.2.14
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 7
- item_type: ACL_RULE
id: '41'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 8
- item_type: ACL_RULE
id: '42'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 9
- item_type: ACL_RULE
id: '43'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 10
- item_type: ACL_RULE
id: '44'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 11
- item_type: ACL_RULE
id: '45'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 12
- item_type: ACL_RULE
id: '46'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 13
- item_type: ACL_RULE
id: '47'
permission: ALLOW
source: 192.168.2.14
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 14
- item_type: ACL_RULE
id: '48'
permission: ALLOW
source: 192.168.2.16
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 15
- item_type: ACL_RULE
id: '49'
permission: DENY
source: ANY
destination: ANY
protocol: ANY
port: ANY
position: 16
- item_type: RED_POL
id: '50'
start_step: 50
end_step: 50
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: UDP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '51'
start_step: 75
end_step: 105
load: 10000
protocol: UDP
port: '53'
source: '1'
destination: '8'
mission_criticality: 0
- item_type: RED_POL
id: '52'
start_step: 100
end_step: 100
targetNodeId: '8'
initiator: IER
type: SERVICE
protocol: UDP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_POL
id: '53'
start_step: 105
end_step: 105
targetNodeId: '8'
initiator: SERVICE
type: FILE
protocol: NA
state: CORRUPT
sourceNodeId: '8'
sourceNodeService: UDP
sourceNodeServiceState: COMPROMISED
- item_type: RED_POL
id: '54'
start_step: 105
end_step: 105
targetNodeId: '8'
initiator: SERVICE
type: SERVICE
protocol: TCP_SQL
state: COMPROMISED
sourceNodeId: '8'
sourceNodeService: UDP
sourceNodeServiceState: COMPROMISED
- item_type: RED_POL
id: '55'
start_step: 125
end_step: 125
targetNodeId: '7'
initiator: SERVICE
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: '8'
sourceNodeService: TCP_SQL
sourceNodeServiceState: COMPROMISED
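
The ACL_RULE items above carry a `position` and end with a catch-all DENY rule. The matching logic itself is not shown in this file, so the following is only a sketch under the assumption that rules are checked in ascending `position`, that the first matching rule wins, and that `ANY` is a wildcard. The function names `rule_matches` and `first_match_permission` are hypothetical:

```python
from typing import Any, Dict, List

# Two rules taken from the lay down above, deliberately listed out of order.
rules: List[Dict[str, Any]] = [
    {"permission": "DENY", "source": "ANY", "destination": "ANY",
     "protocol": "ANY", "port": "ANY", "position": 16},
    {"permission": "ALLOW", "source": "192.168.10.11", "destination": "192.168.2.10",
     "protocol": "ANY", "port": "ANY", "position": 0},
]


def rule_matches(rule: Dict[str, Any], source: str, destination: str,
                 protocol: str, port: str) -> bool:
    """A field matches if the rule holds the same value or the wildcard ANY."""
    wanted = {"source": source, "destination": destination,
              "protocol": protocol, "port": port}
    return all(str(rule[name]) in ("ANY", str(value)) for name, value in wanted.items())


def first_match_permission(acl: List[Dict[str, Any]], source: str, destination: str,
                           protocol: str, port: str, implicit: str = "DENY") -> str:
    """Assumed semantics: lowest position first, first match wins, else implicit rule."""
    for rule in sorted(acl, key=lambda r: r["position"]):
        if rule_matches(rule, source, destination, protocol, port):
            return rule["permission"]
    return implicit


client_to_web = first_match_permission(rules, "192.168.10.11", "192.168.2.10", "TCP", "80")
client_to_console = first_match_permission(rules, "192.168.10.12", "192.168.1.12", "UDP", "53")
```

Under these assumptions the first call is allowed by the rule at position 0 and the second falls through to the catch-all DENY at position 16.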


@@ -0,0 +1,168 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (TensorFlow).
# Options are:
# "TF" (TensorFlow)
# "TF2" (TensorFlow 2.X)
# "TORCH" (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER are randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Sets whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets how the action space is defined.
# Options are:
# "NODE" (Node actions only)
# "ACL" (ACL actions only)
# "ANY" (Both node and ACL actions)
action_type: ANY
# observation space
observation_space:
flatten: true
components:
- name: NODE_LINK_TABLE
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 30
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system
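
The option strings in this file (for example `agent_framework: SB3` or `session_type: TRAIN_EVAL`) are resolved to enum members by name when the config is loaded. The sketch below shows that by-name lookup in isolation. The two enum classes are stand-ins with hypothetical member values (the real ones live in `primaite.common.enums`), and `resolve_enums` is an illustrative helper, not a PrimAITE function:

```python
from enum import Enum
from typing import Any, Dict


class AgentFramework(Enum):
    """Stand-in enum; member values are hypothetical."""
    CUSTOM = 0
    SB3 = 1
    RLLIB = 2


class SessionType(Enum):
    """Stand-in enum; member values are hypothetical."""
    TRAIN = 1
    EVAL = 2
    TRAIN_EVAL = 3


FIELD_ENUM_MAP = {"agent_framework": AgentFramework, "session_type": SessionType}


def resolve_enums(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of config with enum-typed fields looked up by member name."""
    resolved = dict(config)  # shallow copy, so the caller's dict is left untouched
    for key, enum_cls in FIELD_ENUM_MAP.items():
        if key in resolved:
            # EnumClass["NAME"] looks a member up by name and raises KeyError
            # if no member with that exact (case-sensitive) name exists.
            resolved[key] = enum_cls[resolved[key]]
    return resolved


raw = {"agent_framework": "SB3", "session_type": "TRAIN_EVAL", "num_train_episodes": 10}
resolved = resolve_enums(raw)
```

Because the lookup is by exact member name, a value such as `sb3` or `Train` would raise a `KeyError` rather than being silently accepted.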


@@ -0,0 +1,141 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, List, Union
import yaml
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Convert a legacy lay down config to the new format.
:param legacy_config: A legacy lay down config.
"""
field_conversion_map = {
"itemType": "item_type",
"portsList": "ports_list",
"serviceList": "service_list",
"baseType": "node_class",
"nodeType": "node_type",
"hardwareState": "hardware_state",
"softwareState": "software_state",
"startStep": "start_step",
"endStep": "end_step",
"fileSystemState": "file_system_state",
"ipAddress": "ip_address",
"missionCriticality": "mission_criticality",
}
new_config = []
for item in legacy_config:
if "itemType" in item:
if item["itemType"] in ["ACTIONS", "STEPS"]:
continue
new_dict = {}
for key in item.keys():
conversion_key = field_conversion_map.get(key)
if key == "id" and "itemType" in item:
if item["itemType"] == "NODE":
conversion_key = "node_id"
if conversion_key:
new_dict[conversion_key] = item[key]
else:
new_dict[key] = item[key]
new_config.append(new_dict)
return new_config
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
"""
Read in a lay down config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise False.
:return: The lay down config as a dict.
:raises ValueError: If the file_path does not exist.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading lay down config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_lay_down_config(config)
except KeyError:
msg = (
f"Failed to convert lay down config file {file_path} "
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
return config
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def ddos_basic_two_config_path() -> Path:
"""
The path to the example lay_down_config_2_DDOS_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def dos_very_basic_config_path() -> Path:
"""
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def data_manipulation_config_path() -> Path:
"""
The path to the example lay_down_config_5_data_manipulation.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
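
To make the effect of `convert_legacy_lay_down_config` concrete, here is a self-contained sketch that re-declares a subset of its key map and applies the same rules to a small legacy config. The sample legacy items (including the shape of the `STEPS` item) and the name `convert_items` are hypothetical and exist only for illustration:

```python
from typing import Any, Dict, List

# Subset of the camelCase -> snake_case map used by the converter above.
FIELD_CONVERSION_MAP = {
    "itemType": "item_type",
    "baseType": "node_class",
    "nodeType": "node_type",
    "hardwareState": "hardware_state",
    "ipAddress": "ip_address",
}


def convert_items(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Rename legacy keys, drop ACTIONS/STEPS items, and rename a NODE's id to node_id."""
    new_config = []
    for item in legacy_config:
        if item.get("itemType") in ("ACTIONS", "STEPS"):
            continue  # these item types are not carried over to the new format
        new_item = {}
        for key, value in item.items():
            new_key = FIELD_CONVERSION_MAP.get(key, key)  # unknown keys pass through
            if key == "id" and item.get("itemType") == "NODE":
                new_key = "node_id"
            new_item[new_key] = value
        new_config.append(new_item)
    return new_config


legacy = [
    {"itemType": "STEPS", "steps": 128},
    {"itemType": "NODE", "id": "1", "name": "PC1", "baseType": "SERVICE", "ipAddress": "192.168.1.2"},
    {"itemType": "LINK", "id": "5", "name": "link1"},
]
converted = convert_items(legacy)
```

Only NODE items have `id` renamed to `node_id`; the LINK item keeps its `id` key, which matches the new-format lay down files.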


@@ -0,0 +1,438 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
from dataclasses import dataclass, field
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
import yaml
from primaite import getLogger, PRIMAITE_PATHS
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
HardCodedAgentView,
RulePermissionType,
SB3OutputVerboseLevel,
SessionType,
)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training"
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
@dataclass()
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The AgentFramework"
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
"The DeepLearningFramework"
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
"The AgentIdentifier"
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
"The view the deterministic hard-coded agent has of the environment"
random_red_agent: bool = False
"Creates Random Red Agent Attacks"
action_type: ActionType = ActionType.ANY
"The ActionType to use"
num_train_episodes: int = 10
"The number of episodes to train over during a training session"
num_train_steps: int = 256
"The number of steps in an episode during a training session"
num_eval_episodes: int = 1
"The number of episodes to evaluate over during an evaluation session"
num_eval_steps: int = 256
"The number of steps in an episode during an evaluation session"
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes"
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
"The observation space config dict"
time_delay: int = 10
"The delay between steps (ms). Applies to generic agents only"
# file
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: bool = False
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
"File path and file name of agent if you're loading one in"
# Environment
observation_space_high_value: int = 1000000000
"The high value for the observation space"
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
"Stable Baselines3 learn/eval output verbosity level"
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
"The implicit ALLOW or DENY firewall rule that goes at the end of the ACL list."
max_number_acl_rules: int = 30
"Sets a limit for the number of ACL rules allowed in the list and environment."
# Reward values
# Generic
all_ok: float = 0
# Node Hardware State
off_should_be_on: float = -0.001
off_should_be_resetting: float = -0.0005
on_should_be_off: float = -0.0002
on_should_be_resetting: float = -0.0005
resetting_should_be_on: float = -0.0005
resetting_should_be_off: float = -0.0002
resetting: float = -0.0003
# Node Software or Service State
good_should_be_patching: float = 0.0002
good_should_be_compromised: float = 0.0005
good_should_be_overwhelmed: float = 0.0005
patching_should_be_good: float = -0.0005
patching_should_be_compromised: float = 0.0002
patching_should_be_overwhelmed: float = 0.0002
patching: float = -0.0003
compromised_should_be_good: float = -0.002
compromised_should_be_patching: float = -0.002
compromised_should_be_overwhelmed: float = -0.002
compromised: float = -0.002
overwhelmed_should_be_good: float = -0.002
overwhelmed_should_be_patching: float = -0.002
overwhelmed_should_be_compromised: float = -0.002
overwhelmed: float = -0.002
# Node File System State
good_should_be_repairing: float = 0.0002
good_should_be_restoring: float = 0.0002
good_should_be_corrupt: float = 0.0005
good_should_be_destroyed: float = 0.001
repairing_should_be_good: float = -0.0005
repairing_should_be_restoring: float = 0.0002
repairing_should_be_corrupt: float = 0.0002
repairing_should_be_destroyed: float = 0.0000
repairing: float = -0.0003
restoring_should_be_good: float = -0.001
restoring_should_be_repairing: float = -0.0002
restoring_should_be_corrupt: float = 0.0001
restoring_should_be_destroyed: float = 0.0002
restoring: float = -0.0006
corrupt_should_be_good: float = -0.001
corrupt_should_be_repairing: float = -0.001
corrupt_should_be_restoring: float = -0.001
corrupt_should_be_destroyed: float = 0.0002
corrupt: float = -0.001
destroyed_should_be_good: float = -0.002
destroyed_should_be_repairing: float = -0.002
destroyed_should_be_restoring: float = -0.002
destroyed_should_be_corrupt: float = -0.002
destroyed: float = -0.002
scanning: float = -0.0002
# IER status
red_ier_running: float = -0.0005
green_ier_blocked: float = -0.001
# Patching / Reset durations
os_patching_duration: int = 5
"The time taken to patch the OS"
node_reset_duration: int = 5
"The time taken to reset a node (hardware)"
node_booting_duration: int = 3
"The time taken to turn on the node"
node_shutdown_duration: int = 2
"The time taken to turn off the node"
service_patching_duration: int = 5
"The time taken to patch a service"
file_system_repairing_limit: int = 5
"The time taken to repair the file system"
file_system_restoring_limit: int = 5
"The time taken to restore the file system"
file_system_scanning_limit: int = 5
"The time taken to scan the file system"
deterministic: bool = False
"If true, the training will be deterministic"
seed: Optional[int] = None
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
:param config_dict: The training config dict.
:return: The instance of TrainingConfig.
"""
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
"agent_identifier": AgentIdentifier,
"action_type": ActionType,
"session_type": SessionType,
"sb3_output_verbose_level": SB3OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView,
"implicit_acl_rule": RulePermissionType,
}
# Convert the string representation of each enum into the enum member itself.
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return cls(**config_dict)
def to_dict(self, json_serializable: bool = True) -> Dict:
"""
Serialise the ``TrainingConfig`` as dict.
:param json_serializable: If True, Enums are converted to their
string name.
:return: The ``TrainingConfig`` as a dict.
"""
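# Copy the instance dict so that replacing enums with their names does not overwrite the instance attributes.
data = self.__dict__.copy()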
if json_serializable:
data["agent_framework"] = self.agent_framework.name
data["deep_learning_framework"] = self.deep_learning_framework.name
data["agent_identifier"] = self.agent_identifier.name
data["action_type"] = self.action_type.name
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
data["session_type"] = self.session_type.name
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
data["implicit_acl_rule"] = self.implicit_acl_rule.name
return data
def __str__(self) -> str:
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
tc = f"{self.agent_framework}, "
if self.agent_framework is AgentFramework.RLLIB:
tc += f"{self.deep_learning_framework}, "
tc += f"{self.agent_identifier}, "
if self.agent_identifier is AgentIdentifier.HARDCODED:
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={obs_str}, "
if self.session_type is SessionType.TRAIN:
tc += f"{self.num_train_episodes} episodes @ "
tc += f"{self.num_train_steps} steps"
elif self.session_type is SessionType.EVAL:
tc += f"{self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
else:
tc += f"Training: {self.num_train_episodes} episodes @ "
tc += f"{self.num_train_steps} steps, "
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
return tc
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise
False.
:return: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:raises ValueError: If the file_path does not exist.
:raises TypeError: When the TrainingConfig object cannot be created
using the values from the config file read from ``file_path``.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading training config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_training_config_dict(config)
except KeyError as e:
msg = f"Failed to convert training config file {file_path} from legacy format."
_LOGGER.error(msg)
raise e
try:
return TrainingConfig.from_dict(config)
except TypeError as e:
msg = f"Error when creating an instance of {TrainingConfig} from the training config file {file_path}"
_LOGGER.critical(msg, exc_info=True)
raise e
msg = f"Cannot load the training config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_train_steps: int = 256,
num_eval_steps: int = 256,
num_train_episodes: int = 10,
num_eval_episodes: int = 1,
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param agent_identifier: The red agent identifier to use as legacy
training configs don't have agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_train_steps: The number of train steps to set as legacy training configs
don't have num_train_steps values.
:param num_eval_steps: The number of eval steps to set as legacy training configs
don't have num_eval_steps values.
:param num_train_episodes: The number of train episodes to set as legacy training configs
don't have num_train_episodes values.
:param num_eval_episodes: The number of eval episodes to set as legacy training configs
don't have num_eval_episodes values.
:return: The converted training config dict.
"""
config_dict = {
"agent_framework": agent_framework.name,
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_train_steps": num_train_steps,
"num_eval_steps": num_eval_steps,
"num_train_episodes": num_train_episodes,
"num_eval_episodes": num_eval_episodes,
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:
config_dict[new_key] = value
return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
"""
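Map a legacy training config key to its new-format key.
:param legacy_key: A legacy training config key.
:return: The mapped key, or ``None`` if the legacy key has no new-format equivalent.
:raises KeyError: If ``legacy_key`` is not a known legacy key.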
"""
key_mapping = {
"agentIdentifier": None,
"numEpisodes": "num_train_episodes",
"numSteps": "num_train_steps",
"timeDelay": "time_delay",
"configFilename": None,
"sessionType": "session_type",
"loadAgent": "load_agent",
"agentLoadFile": "agent_load_file",
"observationSpaceHighValue": "observation_space_high_value",
"allOk": "all_ok",
"offShouldBeOn": "off_should_be_on",
"offShouldBeResetting": "off_should_be_resetting",
"onShouldBeOff": "on_should_be_off",
"onShouldBeResetting": "on_should_be_resetting",
"resettingShouldBeOn": "resetting_should_be_on",
"resettingShouldBeOff": "resetting_should_be_off",
"resetting": "resetting",
"goodShouldBePatching": "good_should_be_patching",
"goodShouldBeCompromised": "good_should_be_compromised",
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
"patchingShouldBeGood": "patching_should_be_good",
"patchingShouldBeCompromised": "patching_should_be_compromised",
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
"patching": "patching",
"compromisedShouldBeGood": "compromised_should_be_good",
"compromisedShouldBePatching": "compromised_should_be_patching",
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
"compromised": "compromised",
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
"overwhelmed": "overwhelmed",
"goodShouldBeRepairing": "good_should_be_repairing",
"goodShouldBeRestoring": "good_should_be_restoring",
"goodShouldBeCorrupt": "good_should_be_corrupt",
"goodShouldBeDestroyed": "good_should_be_destroyed",
"repairingShouldBeGood": "repairing_should_be_good",
"repairingShouldBeRestoring": "repairing_should_be_restoring",
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
"repairing": "repairing",
"restoringShouldBeGood": "restoring_should_be_good",
"restoringShouldBeRepairing": "restoring_should_be_repairing",
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
"restoring": "restoring",
"corruptShouldBeGood": "corrupt_should_be_good",
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
"corrupt": "corrupt",
"destroyedShouldBeGood": "destroyed_should_be_good",
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
"destroyed": "destroyed",
"scanning": "scanning",
"redIerRunning": "red_ier_running",
"greenIerBlocked": "green_ier_blocked",
"osPatchingDuration": "os_patching_duration",
"nodeResetDuration": "node_reset_duration",
"nodeBootingDuration": "node_booting_duration",
"nodeShutdownDuration": "node_shutdown_duration",
"servicePatchingDuration": "service_patching_duration",
"fileSystemRepairingLimit": "file_system_repairing_limit",
"fileSystemRestoringLimit": "file_system_restoring_limit",
"fileSystemScanningLimit": "file_system_scanning_limit",
}
return key_mapping[legacy_key]
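
The legacy conversion above can be illustrated with a minimal, self-contained sketch. The mapping below is a deliberately abbreviated, hypothetical subset of the real table, and `convert_legacy` is an illustrative name rather than the module's own function. Unlike the source, which raises `KeyError` for unknown keys and rewrites `sessionType` in the input dict, this sketch skips unknown keys and leaves the input untouched; both choices are assumptions made for the example.

```python
from typing import Any, Dict, Optional

# Hypothetical, abbreviated mapping for illustration only.
# A value of None marks a legacy key that has no new-format equivalent.
KEY_MAPPING: Dict[str, Optional[str]] = {
    "agentIdentifier": None,
    "numEpisodes": "num_train_episodes",
    "numSteps": "num_train_steps",
    "sessionType": "session_type",
    "timeDelay": "time_delay",
}
SESSION_TYPE_MAP = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}


def convert_legacy(legacy: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new-format dict without mutating the legacy input."""
    converted: Dict[str, Any] = {}
    for legacy_key, value in legacy.items():
        # Assumption: unknown keys are skipped here rather than raising KeyError.
        new_key = KEY_MAPPING.get(legacy_key)
        if new_key is None:
            continue
        if new_key == "session_type":
            value = SESSION_TYPE_MAP[value]
        converted[new_key] = value
    return converted


legacy_config = {
    "agentIdentifier": "NONE",
    "numEpisodes": 10,
    "numSteps": 256,
    "sessionType": "TRAINING",
    "timeDelay": 1,
}
new_config = convert_legacy(legacy_config)
print(new_config)
```

Keys mapped to `None` (here `agentIdentifier`) are dropped, which matches the `if new_key:` check in the conversion function.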


@@ -0,0 +1,15 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utility to generate plots of session metrics after a PrimAITE session has run."""
from enum import Enum
class PlotlyTemplate(Enum):
"""The built-in plotly templates."""
PLOTLY = "plotly"
PLOTLY_WHITE = "plotly_white"
PLOTLY_DARK = "plotly_dark"
GGPLOT2 = "ggplot2"
SEABORN = "seaborn"
SIMPLE_WHITE = "simple_white"
NONE = "none"
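
As a rough illustration of how such an enum might be used, the sketch below validates a template name read from a config file before it is handed to plotly. The `resolve_template` helper and its fallback to `"plotly"` are assumptions made for this example; nothing in the source says the project validates the value this way.

```python
from enum import Enum


class PlotlyTemplate(Enum):
    """The built-in plotly templates (copied from the enum above)."""

    PLOTLY = "plotly"
    PLOTLY_WHITE = "plotly_white"
    PLOTLY_DARK = "plotly_dark"
    GGPLOT2 = "ggplot2"
    SEABORN = "seaborn"
    SIMPLE_WHITE = "simple_white"
    NONE = "none"


def resolve_template(name: str) -> str:
    """Return ``name`` if it is a known template, else fall back to 'plotly'.

    Hypothetical helper: the fallback policy is an assumption.
    """
    try:
        # Enum lookup by value raises ValueError for unknown names.
        return PlotlyTemplate(name).value
    except ValueError:
        return PlotlyTemplate.PLOTLY.value


print(resolve_template("plotly_dark"))
print(resolve_template("not_a_template"))
```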


@@ -0,0 +1,73 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Dict, Optional, Union
import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure
from primaite import PRIMAITE_PATHS
def get_plotly_config() -> Dict:
"""Get the plotly config from primaite_config.yaml."""
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
return primaite_config["session"]["outputs"]["plots"]
def plot_av_reward_per_episode(
av_reward_per_episode_csv: Union[str, Path],
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> Figure:
"""
Plot the average reward per episode from a csv session output.
:param av_reward_per_episode_csv: The average reward per episode csv
file path.
:param title: The plot title. This is optional.
:param subtitle: The plot subtitle. This is optional.
:return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
"""
df = pl.read_csv(av_reward_per_episode_csv)
if title:
if subtitle:
title = f"{title} <br><sup>{subtitle}</sup>"
else:
if subtitle:
title = subtitle
config = get_plotly_config()
layout = go.Layout(
autosize=config["size"]["auto_size"],
width=config["size"]["width"],
height=config["size"]["height"],
)
# Create the line graph with a colored line
fig = go.Figure(layout=layout)
fig.update_layout(template=config["template"])
fig.add_trace(
go.Scatter(
x=df["Episode"],
y=df["Average Reward"],
mode="lines",
name="Mean Reward per Episode",
)
)
# Set the layout of the graph
fig.update_layout(
xaxis={
"title": "Episode",
"type": "linear",
"rangeslider": {"visible": config["range_slider"]},
},
yaxis={"title": "Average Reward"},
title=title,
showlegend=False,
)
return fig
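
The title handling in `plot_av_reward_per_episode` can be isolated into a small pure function, which makes the four title/subtitle cases easy to check without plotly or polars installed. `compose_title` is a hypothetical name, and the sketch assumes the subtitle is meant to render as smaller text on a second line via a `<sup>` tag.

```python
from typing import Optional


def compose_title(title: Optional[str] = None, subtitle: Optional[str] = None) -> Optional[str]:
    """Combine a title and subtitle into a single plotly title string."""
    if title:
        if subtitle:
            # Subtitle goes on a new line in smaller text.
            return f"{title} <br><sup>{subtitle}</sup>"
        return title
    # No title: fall back to the subtitle, if there is one.
    return subtitle if subtitle else title


print(compose_title("Session 1", "PPO agent"))
print(compose_title("Session 1"))
print(compose_title(None, "PPO agent"))
print(compose_title())
```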


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""


@@ -0,0 +1,735 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from logging import Logger
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
from gym import spaces
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# This dependency is only needed for type hints,
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite") -> None:
"""
Initialise observation component.
:param env: Primaite training environment.
:type env: Primaite
"""
_LOGGER.info(f"Initialising {self} observation component")
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
self.structure: List[str]
@abstractmethod
def update(self) -> None:
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@abstractmethod
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
return NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""
Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
The number of columns is 4 plus one per service. They are:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
"""
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeLinkTable observation space component.
:param env: Training environment.
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
num_items = self.env.num_links + self.env.num_nodes
num_columns = self.env.num_services + self._FIXED_PARAMETERS
observation_shape = (num_items, num_columns)
# 2. Create Observation space
self.space = spaces.Box(
low=0,
high=self._MAX_VAL,
shape=observation_shape,
dtype=self._DATA_TYPE,
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
item_index = 0
nodes = self.env.nodes
links = self.env.links
# Do nodes first
for _, node in nodes.items():
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, (ActiveNode, ServiceNode)):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][3] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][service_index] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.env.services_list:
self.current_observation[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for _, link in links.items():
self.current_observation[item_index][0] = int(link.get_id())
self.current_observation[item_index][1] = 0
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
structure = []
for i, node in enumerate(nodes):
node_id = node.node_id
node_labels = [
f"node_{node_id}_id",
f"node_{node_id}_hardware_status",
f"node_{node_id}_os_status",
f"node_{node_id}_fs_status",
]
for j, serv in enumerate(self.env.services_list):
node_labels.append(f"node_{node_id}_service_{serv}_status")
structure.extend(node_labels)
for i, link in enumerate(links):
link_id = link.id
link_labels = [
f"link_{link_id}_id",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
]
for j, serv in enumerate(self.env.services_list):
link_labels.append(f"link_{link_id}_service_{serv}_load")
structure.extend(link_labels)
return structure
class NodeStatuses(AbstractObservationComponent):
"""
Flat list of nodes' hardware, OS, file system, and service states.
The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states, represented by
integers.
Each node has 3 elements plus 1 per service. It will have the following structure:
.. code-block::
[
node1 hardware state,
node1 OS state,
node1 file system state,
node1 service1 state,
node1 service2 state,
node1 serviceN state (one for each service),
node2 hardware state,
node2 OS state,
node2 file system state,
node2 service1 state,
node2 service2 state,
node2 serviceN state (one for each service),
...
]
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeStatuses observation component.
:param env: Training environment.
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
node_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
obs = []
for _, node in self.env.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
service_states = [0] * self.env.num_services
if isinstance(node, ActiveNode):
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*service_states,
]
)
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
structure = []
for _, node in self.env.nodes.items():
node_id = node.node_id
structure.append(f"node_{node_id}_hardware_state_NONE")
for state in HardwareState:
structure.append(f"node_{node_id}_hardware_state_{state.name}")
structure.append(f"node_{node_id}_software_state_NONE")
for state in SoftwareState:
structure.append(f"node_{node_id}_software_state_{state.name}")
structure.append(f"node_{node_id}_file_system_state_NONE")
for state in FileSystemState:
structure.append(f"node_{node_id}_file_system_state_{state.name}")
for service in services:
structure.append(f"node_{node_id}_service_{service}_state_NONE")
for state in SoftwareState:
structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
return structure
class LinkTrafficLevels(AbstractObservationComponent):
"""
Flat list of traffic levels encoded into banded categories.
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
* 0 = No traffic (0% of bandwidth)
* 1 = Low traffic (0%-33% of bandwidth)
* 2 = Medium traffic (33%-66% of bandwidth)
* 3 = High traffic (66%-100% of bandwidth)
* 4 = Max traffic (100% of bandwidth)
.. note::
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
"""
_DATA_TYPE: type = np.int64
def __init__(
self,
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
) -> None:
"""
Initialise a LinkTrafficLevels observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical
value, defaults to 5
:type quantisation_levels: int, optional
"""
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
f"Resetting to default value (5)"
)
_LOGGER.warning(_msg)
quantisation_levels = 5
super().__init__(env)
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
if not self._combine_service_traffic:
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
obs.append(int(traffic_level))
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
link_id = link.id
if self._combine_service_traffic:
protocols = ["overall"]
else:
protocols = [protocol.name for protocol in link.protocol_list]
for p in protocols:
for i in range(self._quantisation_levels):
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
return structure
class AccessControlList(AbstractObservationComponent):
"""Flat list of all the Access Control Rules in the Access Control List.
The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states, represented by
integers.
Each ACL Rule has 6 elements. It will have the following structure:
.. code-block::
[
acl_rule1 permission,
acl_rule1 source_ip,
acl_rule1 dest_ip,
acl_rule1 protocol,
acl_rule1 port,
acl_rule1 position,
acl_rule2 permission,
acl_rule2 source_ip,
acl_rule2 dest_ip,
acl_rule2 protocol,
acl_rule2 port,
acl_rule2 position,
...
]
Terms (for ACL Observation Space):
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
NOTE: NA means Not Applicable: the entry in the ACL list is ``None`` rather than an ACLRule object.
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
"""
Initialise an AccessControlList observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
# Number of ACL rules incremented by 1 for positions starting at index 0.
acl_shape = [
len(RulePermissionType),
len(env.nodes) + 2,
len(env.nodes) + 2,
len(env.services_list) + 2,
len(env.ports_list) + 2,
env.max_number_acl_rules,
]
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission
source_ip = acl_rule.source_ip
dest_ip = acl_rule.dest_ip
protocol = acl_rule.protocol
port = acl_rule.port
position = index
# Map each ACL attribute from what it was to an integer to fit the observation space
source_ip_int = None
dest_ip_int = None
if permission == RulePermissionType.DENY:
permission_int = 1
else:
permission_int = 2
if source_ip == "ANY":
source_ip_int = 1
else:
# Map Node ID (+ 1) to source IP address
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == source_ip:
source_ip_int = int(node.node_id) + 1
break
if dest_ip == "ANY":
dest_ip_int = 1
else:
# Map Node ID (+ 1) to dest IP address
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip:
dest_ip_int = int(node.node_id) + 1
if protocol == "ANY":
protocol_int = 1
else:
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
try:
protocol_int = self.env.services_list.index(protocol) + 2
except ValueError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = None
if port == "ANY":
port_int = 1
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port) + 2
else:
_LOGGER.info(f"Port {port} could not be found.")
port_int = None
# Add to current obs
obs.extend(
[
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
)
else:
# The Nothing or NA representation of 'NONE' ACL rules
obs.extend([0, 0, 0, 0, 0, 0])
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
# Use the position in the list as the rule id; list.index() would return 0 for every empty (None) slot.
for acl_rule_id, _ in enumerate(self.env.acl.acl):
for permission in RulePermissionType:
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
for service in self.env.services_list:
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
for port in self.env.ports_list:
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
return structure
class ObservationsHandler:
"""
Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components. Each component can also define
further parameters to make them more flexible.
"""
_REGISTRY: Final[Dict[str, type]] = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
"ACCESS_CONTROL_LIST": AccessControlList,
}
def __init__(self) -> None:
"""Initialise the observation handler."""
self.registered_obs_components: List[AbstractObservationComponent] = []
# the internal observation space (the unflattened version when several components are registered)
self._space: spaces.Space
# flattened version of the observation space
self._flat_space: spaces.Space
self._observation: Union[Tuple[np.ndarray], np.ndarray]
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
def update_obs(self) -> None:
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
obs.update()
current_obs.append(obs.current_observation)
if len(current_obs) == 1:
self._observation = current_obs[0]
else:
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent) -> None:
"""
Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent) -> None:
"""
Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self) -> None:
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# if there are multiple components, build a composite tuple space
if len(component_spaces) == 1:
self._space = component_spaces[0]
else:
self._space = spaces.Tuple(component_spaces)
if len(component_spaces) > 0:
self._flat_space = spaces.flatten_space(self._space)
else:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self) -> spaces.Space:
"""Observation space; the flattened version is returned when more than one component is registered."""
if len(self.registered_obs_components) > 1:
return self._flat_space
else:
return self._space
@property
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""Current observation; the flattened version is returned when more than one component is registered."""
if len(self.registered_obs_components) > 1:
return self._flat_observation
else:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
"""
Parse a config dictionary, return a new observation handler populated with new observation component objects.
The expected format for the config dictionary is:
.. code-block:: python
config = {
"components": [
{
"name": "<COMPONENT1_NAME>"
},
{
"name": "<COMPONENT2_NAME>",
"options": {"opt1": val1, "opt2": val2}
},
{
...
},
]
}
:return: Observation handler
:rtype: primaite.environment.observations.ObservationsHandler
"""
# Instantiate the handler
handler = cls()
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls._REGISTRY[comp_type]
# Create the component with options from the YAML
options = component_cfg.get("options") or {}
component = comp_class(env, **options)
handler.register(component)
handler.update_obs()
return handler
def describe_structure(self) -> List[str]:
"""
Create a list of names for the features of the obs space.
The order of labels follows the flattened version of the space.
"""
# as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
# to fake it. each component has to just hard-code the expected label order after flattening...
labels = []
for obs_comp in self.registered_obs_components:
labels.extend(obs_comp.structure)
return labels
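
The banding rule used by `LinkTrafficLevels.update` can be tried in isolation with the sketch below. `traffic_level` is a hypothetical standalone function that follows the same formula as the source; the one deliberate difference is that it raises `ValueError` for fewer than three levels, whereas the component logs a warning and resets to 5.

```python
def traffic_level(load: float, bandwidth: float, quantisation_levels: int = 5) -> int:
    """Encode a link load as a categorical traffic level."""
    if quantisation_levels < 3:
        # Assumption: fail loudly here instead of resetting to the default.
        raise ValueError("quantisation_levels must be 3 or more")
    if load <= 0:
        return 0  # lowest level is reserved for no traffic
    if load >= bandwidth:
        return quantisation_levels - 1  # highest level is reserved for a saturated link
    # The remaining levels split the open interval (0%, 100%) into equal bands.
    band = 1 / (quantisation_levels - 2)
    return int((load / bandwidth) // band) + 1


# Loads on a link with bandwidth 100 and the default 5 levels.
levels = [traffic_level(load, 100) for load in (0, 10, 50, 90, 100, 150)]
print(levels)
```

Because the rule uses floating-point floor division, loads that sit exactly on a band boundary may be sensitive to rounding; the sample loads above are chosen well inside their bands.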

File diff suppressed because it is too large


@@ -0,0 +1,386 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements reward function."""
from logging import Logger
from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
if TYPE_CHECKING:
from primaite.config.training_config import TrainingConfig
from primaite.pol.ier import IER
_LOGGER: Logger = getLogger(__name__)
def calculate_reward_function(
initial_nodes: Dict[str, NodeUnion],
final_nodes: Dict[str, NodeUnion],
reference_nodes: Dict[str, NodeUnion],
green_iers: Dict[str, "IER"],
green_iers_reference: Dict[str, "IER"],
red_iers: Dict[str, "IER"],
step_count: int,
config_values: "TrainingConfig",
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
Args:
initial_nodes: The nodes before red and blue agents take effect
final_nodes: The nodes after red and blue agents take effect
reference_nodes: The nodes if there had been no red or blue effect
green_iers: The green IERs (should be running)
green_iers_reference: The green IERs in the reference environment (no red or blue effect)
red_iers: The red IERs (ideally stopped by the blue agent)
step_count: current step
config_values: Config values
"""
reward_value: float = 0.0
# For each node, compare hardware state, SoftwareState, service states
for node_key, final_node in final_nodes.items():
initial_node = initial_nodes[node_key]
reference_node = reference_nodes[node_key]
# Hardware State
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
# Software State
if isinstance(final_node, (ActiveNode, ServiceNode)):
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
# Service State
if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
# File System State
if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
# Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if ier_value.get_is_running():
reward_value += config_values.red_ier_running
# Go through each green IER - penalise if it's not running (weighted)
# but only if it's supposed to be running (it's running in reference)
for ier_key, ier_value in green_iers.items():
reference_ier = green_iers_reference[ier_key]
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
if live_blocked and not reference_blocked:
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference and live environments. "
f"Penalty of {ier_reward} was NOT applied."
)
)
elif not live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference env but not in the live one. "
f"Penalty of {ier_reward} was NOT applied."
)
)
return reward_value
def score_node_operating_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the hardware state of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_operating_state = final_node.hardware_state
reference_node_operating_state = reference_node.hardware_state
if final_node_operating_state == reference_node_operating_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final (current) state of node (i.e. at every step)
if reference_node_operating_state == HardwareState.ON:
if final_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_on
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_on
else:
pass
elif reference_node_operating_state == HardwareState.OFF:
if final_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_off
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_off
else:
pass
elif reference_node_operating_state == HardwareState.RESETTING:
if final_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_resetting
elif final_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_resetting
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting
else:
pass
else:
pass
return score
def score_node_os_state(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the Software State of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_os_state = final_node.software_state
reference_node_os_state = reference_node.software_state
if final_node_os_state == reference_node_os_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final (current) state of node (i.e. at every step)
if reference_node_os_state == SoftwareState.GOOD:
if final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
else:
pass
elif reference_node_os_state == SoftwareState.PATCHING:
if final_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif reference_node_os_state == SoftwareState.COMPROMISED:
if final_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised
else:
pass
else:
pass
return score
def score_node_service_state(
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the service state(s) of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_services: Dict[str, Service] = final_node.services
reference_node_services: Dict[str, Service] = reference_node.services
for service_key, final_service in final_node_services.items():
reference_service = reference_node_services[service_key]
final_service = final_node_services[service_key]
if final_service.software_state == reference_service.software_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final state of node (i.e. at every step)
if reference_service.software_state == SoftwareState.GOOD:
if final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_good
else:
pass
elif reference_service.software_state == SoftwareState.PATCHING:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_patching
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif reference_service.software_state == SoftwareState.COMPROMISED:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_compromised
else:
pass
elif reference_service.software_state == SoftwareState.OVERWHELMED:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_overwhelmed
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_overwhelmed
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_overwhelmed
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed
else:
pass
else:
pass
return score
def score_node_file_system(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the file system state of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_file_system_state = final_node.file_system_state_actual
reference_node_file_system_state = reference_node.file_system_state_actual
final_node_scanning_state = final_node.file_system_scanning
reference_node_scanning_state = reference_node.file_system_scanning
# File System State
if final_node_file_system_state == reference_node_file_system_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final state of node (i.e. at every step)
if reference_node_file_system_state == FileSystemState.GOOD:
if final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_good
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_good
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_good
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_good
else:
pass
elif reference_node_file_system_state == FileSystemState.REPAIRING:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_repairing
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_repairing
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_repairing
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_repairing
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing
else:
pass
elif reference_node_file_system_state == FileSystemState.RESTORING:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_restoring
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_restoring
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_restoring
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_restoring
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring
else:
pass
elif reference_node_file_system_state == FileSystemState.CORRUPT:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_corrupt
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_corrupt
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_corrupt
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_corrupt
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt
else:
pass
elif reference_node_file_system_state == FileSystemState.DESTROYED:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_destroyed
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_destroyed
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_destroyed
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_destroyed
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed
else:
pass
else:
pass
# Scanning State
if final_node_scanning_state == reference_node_scanning_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# We're scanning the file system which incurs a penalty (as it slows down systems)
score += config_values.scanning
return score
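
# Illustrative sketch only, not part of the module above. Every reward attribute used by the
# scoring functions follows the pattern "<final>_should_be_<reference>", so the nested
# if/elif chains could in principle be expressed as a name lookup. The enum below is a
# stand-in for primaite.common.enums.HardwareState, and `score_against_reference` and
# `example_config` are hypothetical names introduced for this example.

```python
from enum import Enum
from types import SimpleNamespace


class HardwareState(Enum):
    """Stand-in for primaite.common.enums.HardwareState (values are illustrative)."""

    ON = 1
    OFF = 2
    RESETTING = 3


def score_against_reference(final: Enum, reference: Enum, config: object) -> float:
    """Return the configured reward for a (final, reference) state pair."""
    if final == reference:
        # Matching states always take the all_ok branch in the functions above.
        return config.all_ok
    # e.g. final=OFF, reference=ON -> "off_should_be_on"
    attr = f"{final.name.lower()}_should_be_{reference.name.lower()}"
    # Pairs with no configured value contribute nothing, like the `else: pass` branches.
    return getattr(config, attr, 0.0)


# Hypothetical reward values; a real TrainingConfig defines its own.
example_config = SimpleNamespace(all_ok=0.0, off_should_be_on=-10.0, resetting_should_be_on=-5.0)
print(score_against_reference(HardwareState.OFF, HardwareState.ON, example_config))
```

# Because equal states return all_ok first, same-state branches such as
# `config_values.resetting` inside the `else:` blocks above are never reached.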

View File

@@ -0,0 +1,11 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
class PrimaiteError(Exception):
    """The root PrimAITE Error."""

    pass


class RLlibAgentError(PrimaiteError):
    """Raised when there is a generic error with an RLlib agent that is specific to PrimAITE."""

    pass

View File

@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Network connections between nodes in the simulation."""

114
src/primaite/links/link.py Normal file
View File

@@ -0,0 +1,114 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The link class."""
from typing import List
from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
"""
Initialise a Link within the simulated network.
:param _id: The link id
:param _bandwidth: The bandwidth of the link (bps)
:param _source_node_name: The name of the source node
:param _dest_node_name: The name of the destination node
:param _services: The names of the protocols (services) to add to the link
"""
self.id: str = _id
self.bandwidth: int = _bandwidth
self.source_node_name: str = _source_node_name
self.dest_node_name: str = _dest_node_name
self.protocol_list: List[Protocol] = []
# Add the default protocols
for protocol_name in _services:
self.add_protocol(protocol_name)
def add_protocol(self, _protocol: str) -> None:
"""
Adds a new protocol to the list of protocols on this link.
Args:
_protocol: The protocol to be added (enum)
"""
self.protocol_list.append(Protocol(_protocol))
def get_id(self) -> str:
"""
Gets link ID.
Returns:
Link ID
"""
return self.id
def get_source_node_name(self) -> str:
"""
Gets source node name.
Returns:
Source node name
"""
return self.source_node_name
def get_dest_node_name(self) -> str:
"""
Gets destination node name.
Returns:
Destination node name
"""
return self.dest_node_name
def get_bandwidth(self) -> int:
"""
Gets bandwidth of link.
Returns:
Link bandwidth (bps)
"""
return self.bandwidth
def get_protocol_list(self) -> List[Protocol]:
"""
Gets list of protocols on this link.
Returns:
List of protocols on this link
"""
return self.protocol_list
def get_current_load(self) -> int:
"""
Gets current total load on this link.
Returns:
Total load on this link (bps)
"""
total_load = 0
for protocol in self.protocol_list:
total_load += protocol.get_load()
return total_load
def add_protocol_load(self, _protocol: str, _load: int) -> None:
"""
Adds a loading to a protocol on this link.
Args:
_protocol: The protocol to load
_load: The amount to load (bps)
"""
for protocol in self.protocol_list:
if protocol.get_name() == _protocol:
protocol.add_load(_load)
else:
pass
def clear_traffic(self) -> None:
"""Clears all traffic on this link."""
for protocol in self.protocol_list:
protocol.clear_load()
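
# Illustrative sketch only: how per-protocol load accumulates into a link's total load.
# `Protocol` here is a minimal stand-in that implements just the four methods Link calls
# (get_name, get_load, add_load, clear_load); the real primaite.common.protocol.Protocol
# may differ. The helper function names are hypothetical.

```python
from typing import List


class Protocol:
    """Stand-in for primaite.common.protocol.Protocol (assumed behaviour)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.load = 0

    def get_name(self) -> str:
        return self.name

    def get_load(self) -> int:
        return self.load

    def add_load(self, load: int) -> None:
        self.load += load

    def clear_load(self) -> None:
        self.load = 0


def current_load(protocols: List[Protocol]) -> int:
    """Sum the load of every protocol, as Link.get_current_load does."""
    return sum(protocol.get_load() for protocol in protocols)


def add_protocol_load(protocols: List[Protocol], name: str, load: int) -> None:
    """Add load to the protocol(s) matching `name`; unknown names are ignored."""
    for protocol in protocols:
        if protocol.get_name() == name:
            protocol.add_load(load)


protocols = [Protocol("TCP"), Protocol("UDP")]
add_protocol_load(protocols, "TCP", 1_000)
add_protocol_load(protocols, "UDP", 250)
add_protocol_load(protocols, "ICMP", 99)  # no such protocol on this link, so nothing changes
total_before = current_load(protocols)

# Equivalent of Link.clear_traffic
for protocol in protocols:
    protocol.clear_load()
total_after = current_load(protocols)
print(total_before, total_after)
```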

57
src/primaite/main.py Normal file
View File

@@ -0,0 +1,57 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run(
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
) -> None:
"""
Run the PrimAITE Session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
session = PrimaiteSession(
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
)
session.setup()
session.learn()
session.evaluate()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tc")
    parser.add_argument("--ldc")
    parser.add_argument("--load")
    args = parser.parse_args()
    if args.load:
        run(session_path=args.load)
    else:
        # Report every missing argument, then only run when both config paths are present.
        if not args.tc:
            _LOGGER.error("Please provide a training config file using the --tc argument")
        if not args.ldc:
            _LOGGER.error("Please provide a lay down config file using the --ldc argument")
        if args.tc and args.ldc:
            run(training_config_path=args.tc, lay_down_config_path=args.ldc)
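
# Illustrative sketch only: one way to validate the three CLI options before starting a
# session, using argparse's own error reporting. `parse_cli` is a hypothetical helper and
# the file names are placeholders; the module above logs through _LOGGER instead.

```python
import argparse
from typing import List, Optional


def parse_cli(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse --tc/--ldc/--load and reject combinations that cannot start a session."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tc")
    parser.add_argument("--ldc")
    parser.add_argument("--load")
    args = parser.parse_args(argv)
    # Either an existing session is loaded, or both config files are needed.
    if not args.load and not (args.tc and args.ldc):
        # parser.error prints usage to stderr and exits with status 2.
        parser.error("provide --load, or both --tc and --ldc")
    return args


args = parse_cli(["--tc", "training.yaml", "--ldc", "laydown.yaml"])
print(args.tc, args.ldc, args.load)
```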

View File

@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Nodes represent network hosts in the simulation."""

View File

@@ -0,0 +1,208 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class ActiveNode(Node):
"""Active Node class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
ip_address: str,
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
) -> None:
"""
Initialise an active node.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
self.ip_address: str = ip_address
# Related to Software
self._software_state: SoftwareState = software_state
self.patching_count: int = 0
# Related to File System
self.file_system_state_actual: FileSystemState = file_system_state
self.file_system_state_observed: FileSystemState = file_system_state
self.file_system_scanning: bool = False
self.file_system_scanning_count: int = 0
self.file_system_action_count: int = 0
@property
def software_state(self) -> SoftwareState:
"""
Get the software_state.
:return: The software_state.
"""
return self._software_state
@software_state.setter
def software_state(self, software_state: SoftwareState) -> None:
"""
Set the software_state.
:param software_state: Software State.
"""
if self.hardware_state != HardwareState.OFF:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = self.config_values.os_patching_duration
else:
_LOGGER.info(
f"The Node's hardware state is OFF so OS State cannot be "
f"changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
"""
Sets Software State if the node is not compromised.
Args:
software_state: Software State
"""
if self.hardware_state != HardwareState.OFF:
if self._software_state != SoftwareState.COMPROMISED:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = self.config_values.os_patching_duration
else:
_LOGGER.info(
f"The Node's hardware state is OFF so OS State cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.software_state:{self._software_state}"
)
def update_os_patching_status(self) -> None:
"""Updates operating system status based on patching cycle."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self._software_state = SoftwareState.GOOD
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed).
Args:
file_system_state: File system state
"""
if self.hardware_state != HardwareState.OFF:
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
else:
_LOGGER.info(
f"The Node's hardware state is OFF so File System State "
f"cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed) if not in a compromised state.
Use for green PoL to prevent it overturning a compromised state
Args:
file_system_state: File system state
"""
if self.hardware_state != HardwareState.OFF:
if (
self.file_system_state_actual != FileSystemState.CORRUPT
and self.file_system_state_actual != FileSystemState.DESTROYED
):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
else:
_LOGGER.info(
f"The Node's hardware state is OFF so File System State (if not "
f"compromised) cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def start_file_system_scan(self) -> None:
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self) -> None:
"""Updates file system status based on scanning/restore/repair cycle."""
# Decrement both the action count (for restoring or repairing) and the scanning count
self.file_system_action_count -= 1
self.file_system_scanning_count -= 1
# Repairing / Restoring updates
if self.file_system_action_count <= 0:
self.file_system_action_count = 0
if (
self.file_system_state_actual == FileSystemState.REPAIRING
or self.file_system_state_actual == FileSystemState.RESTORING
):
self.file_system_state_actual = FileSystemState.GOOD
self.file_system_state_observed = FileSystemState.GOOD
# Scanning updates
if self.file_system_scanning and self.file_system_scanning_count < 0:
self.file_system_state_observed = self.file_system_state_actual
self.file_system_scanning = False
self.file_system_scanning_count = 0
def update_resetting_status(self) -> None:
"""Updates the reset count and sets the software and file system state to GOOD."""
super().update_resetting_status()
if self.resetting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD
def update_booting_status(self) -> None:
"""Updates the booting count and sets the software and file system state to GOOD."""
super().update_booting_status()
if self.booting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD

View File

@@ -0,0 +1,79 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The base Node class."""
from typing import Final
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
class Node:
"""Node class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
) -> None:
"""
Initialise a node.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
self.node_id: Final[str] = node_id
self.name: Final[str] = name
self.node_type: Final[NodeType] = node_type
self.priority = priority
self.hardware_state: HardwareState = hardware_state
self.resetting_count: int = 0
self.config_values: TrainingConfig = config_values
self.booting_count: int = 0
self.shutting_down_count: int = 0
def __repr__(self) -> str:
"""Returns the name of the node."""
return self.name
def turn_on(self) -> None:
"""Sets the node state to BOOTING and starts the booting count."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self) -> None:
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
self.shutting_down_count = self.config_values.node_shutdown_duration
def reset(self) -> None:
"""Sets the node state to Resetting and starts the reset count."""
self.hardware_state = HardwareState.RESETTING
self.resetting_count = self.config_values.node_reset_duration
def update_resetting_status(self) -> None:
"""Updates the resetting count."""
self.resetting_count -= 1
if self.resetting_count <= 0:
self.resetting_count = 0
self.hardware_state = HardwareState.ON
def update_booting_status(self) -> None:
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self) -> None:
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
self.shutting_down_count = 0
self.hardware_state = HardwareState.OFF
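
# Illustrative sketch only: the countdown pattern used by reset()/update_resetting_status().
# `CountdownNode` is a hypothetical cut-down class, the enum is a stand-in for
# primaite.common.enums.HardwareState, and the sketch assumes the environment calls
# update_resetting_status() once per step while the node is RESETTING.

```python
from enum import Enum
from types import SimpleNamespace
from typing import List


class HardwareState(Enum):
    """Stand-in for primaite.common.enums.HardwareState (values are illustrative)."""

    ON = 1
    OFF = 2
    RESETTING = 3


class CountdownNode:
    """Only the reset-related state of Node, for demonstration."""

    def __init__(self, config_values: SimpleNamespace) -> None:
        self.config_values = config_values
        self.hardware_state = HardwareState.ON
        self.resetting_count = 0

    def reset(self) -> None:
        # Enter RESETTING and load the countdown from the config.
        self.hardware_state = HardwareState.RESETTING
        self.resetting_count = self.config_values.node_reset_duration

    def update_resetting_status(self) -> None:
        # One tick of the countdown; the node comes back ON when it reaches zero.
        self.resetting_count -= 1
        if self.resetting_count <= 0:
            self.resetting_count = 0
            self.hardware_state = HardwareState.ON


node = CountdownNode(SimpleNamespace(node_reset_duration=2))  # hypothetical duration
node.reset()
history: List[str] = []
for _ in range(2):
    node.update_resetting_status()
    history.append(node.hardware_state.name)
print(history)
```

# With a duration of 2 the node stays RESETTING after the first update and returns to ON
# after the second. The booting and shutdown counters above follow the same shape.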

View File

@@ -0,0 +1,94 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
class NodeStateInstructionGreen(object):
"""The Node State Instruction class."""
def __init__(
self,
_id: str,
_start_step: int,
_end_step: int,
_node_id: str,
_node_pol_type: "NodePOLType",
_service_name: str,
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
) -> None:
"""
Initialise the Node State Instruction.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _node_id: The id of the associated node
:param _node_pol_type: The pattern of life type
:param _service_name: The service name
:param _state: The state (node or service)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type: "NodePOLType" = _node_pol_type
self.service_name: str = _service_name # Not used when not a service instruction
# TODO: confirm type of state
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
def get_start_step(self) -> int:
"""
Gets the start step.
Returns:
The start step
"""
return self.start_step
def get_end_step(self) -> int:
"""
Gets the end step.
Returns:
The end step
"""
return self.end_step
def get_node_id(self) -> str:
"""
Gets the node ID.
Returns:
The node ID
"""
return self.node_id
def get_node_pol_type(self) -> "NodePOLType":
"""
Gets the node pattern of life type (enum).
Returns:
The node pattern of life type (enum)
"""
return self.node_pol_type
def get_service_name(self) -> str:
"""
Gets the service name.
Returns:
The service name
"""
return self.service_name
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
Returns:
The state (node or service)
"""
return self.state

View File

@@ -0,0 +1,143 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Defines node behaviour for Red PoL."""
from typing import TYPE_CHECKING, Union
from primaite.common.enums import NodePOLType
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
class NodeStateInstructionRed:
"""The Node State Instruction class."""
def __init__(
self,
_id: str,
_start_step: int,
_end_step: int,
_target_node_id: str,
_pol_initiator: "NodePOLInitiator",
_pol_type: NodePOLType,
pol_protocol: str,
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
_pol_source_node_id: str,
_pol_source_node_service: str,
_pol_source_node_service_state: str,
) -> None:
"""
Initialise the Node State Instruction for the red agent.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _target_node_id: The id of the associated node
:param _pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
:param _pol_type: The pattern of life type
:param pol_protocol: The pattern of life protocol/service affected
:param _pol_state: The state (node or service)
:param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.target_node_id: str = _target_node_id
self.initiator: "NodePOLInitiator" = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name: str = pol_protocol # Not used when not a service instruction
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
self.source_node_id: str = _pol_source_node_id
self.source_node_service: str = _pol_source_node_service
self.source_node_service_state = _pol_source_node_service_state
def get_start_step(self) -> int:
"""
Gets the start step.
Returns:
The start step
"""
return self.start_step
def get_end_step(self) -> int:
"""
Gets the end step.
Returns:
The end step
"""
return self.end_step
def get_target_node_id(self) -> str:
"""
Gets the node ID.
Returns:
The node ID
"""
return self.target_node_id
def get_initiator(self) -> "NodePOLInitiator":
"""
Gets the initiator.
Returns:
The initiator
"""
return self.initiator
def get_pol_type(self) -> NodePOLType:
"""
Gets the node pattern of life type (enum).
Returns:
The node pattern of life type (enum)
"""
return self.pol_type
def get_service_name(self) -> str:
"""
Gets the service name.
Returns:
The service name
"""
return self.service_name
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
Returns:
The state (node or service)
"""
return self.state
def get_source_node_id(self) -> str:
"""
Gets the source node id (used for initiator type SERVICE).
Returns:
The source node id
"""
return self.source_node_id
def get_source_node_service(self) -> str:
"""
Gets the source node service (used for initiator type SERVICE).
Returns:
The source node service
"""
return self.source_node_service
def get_source_node_service_state(self) -> str:
"""
Gets the source node service state (used for initiator type SERVICE).
Returns:
The source node service state
"""
return self.source_node_service_state


@@ -0,0 +1,42 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
class PassiveNode(Node):
"""The Passive Node class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
) -> None:
"""
Initialise a passive node.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
# Pass through to Super for now
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
@property
def ip_address(self) -> str:
"""
Gets the node IP address as an empty string.
No concept of IP address for passive nodes for now.
:return: The node IP address.
"""
return ""


@@ -0,0 +1,190 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class ServiceNode(ActiveNode):
"""ServiceNode class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
ip_address: str,
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
) -> None:
"""
Initialise a Service Node.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(
node_id,
name,
node_type,
priority,
hardware_state,
ip_address,
software_state,
file_system_state,
config_values,
)
self.services: Dict[str, Service] = {}
def add_service(self, service: Service) -> None:
"""
Adds a service to the node.
:param service: The service to add
"""
self.services[service.name] = service
def has_service(self, protocol_name: str) -> bool:
"""
Indicates whether a service is on a node.
:param protocol_name: The service (protocol).
:return: True if service (protocol) is on the node, otherwise False.
"""
for service_key, service_value in self.services.items():
if service_key == protocol_name:
return True
return False
def service_running(self, protocol_name: str) -> bool:
"""
Indicates whether a service is in a running state on the node.
:param protocol_name: The service (protocol)
:return: True if service (protocol) is in a running state on the node, otherwise False.
"""
for service_key, service_value in self.services.items():
if service_key == protocol_name:
if service_value.software_state != SoftwareState.PATCHING:
return True
else:
return False
return False
def service_is_overwhelmed(self, protocol_name: str) -> bool:
"""
Indicates whether a service is in an overwhelmed state on the node.
:param protocol_name: The service (protocol)
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
"""
for service_key, service_value in self.services.items():
if service_key == protocol_name:
if service_value.software_state == SoftwareState.OVERWHELMED:
return True
else:
return False
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
:param protocol_name: The service (protocol).
:param software_state: The software_state.
"""
if self.hardware_state != HardwareState.OFF:
service_key = protocol_name
service_value = self.services.get(service_key)
if service_value:
# Can't set to compromised if you're in a patching state
if (
software_state == SoftwareState.COMPROMISED
and service_value.software_state != SoftwareState.PATCHING
) or software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Node's hardware state is OFF so the state of a service "
f"cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.services[<key>]:{protocol_name}, "
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
Done if the software_state is not "compromised".
:param protocol_name: The service (protocol).
:param software_state: The software_state.
"""
if self.hardware_state != HardwareState.OFF:
service_key = protocol_name
service_value = self.services.get(service_key)
if service_value:
if service_value.software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Node's hardware state is OFF so the state of a service "
f"cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.services[<key>]:{protocol_name}, "
f"Node.services[<key>].software_state:{software_state}"
)
def get_service_state(self, protocol_name: str) -> SoftwareState:
"""
Gets the state of a service.
:param protocol_name: The service (protocol).
:return: The software_state of the service, or None if the service is not on the node.
"""
service_key = protocol_name
service_value = self.services.get(service_key)
if service_value:
return service_value.software_state
def update_services_patching_status(self) -> None:
"""Updates the patching counter for any services that are patching."""
for service_key, service_value in self.services.items():
if service_value.software_state == SoftwareState.PATCHING:
service_value.reduce_patching_count()
def update_resetting_status(self) -> None:
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self) -> None:
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
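
The two service-state setters above differ only in which transitions they refuse. A minimal sketch of those rules as pure functions follows; the `SoftwareState` enum here is a stand-in with illustrative values, and both function names are hypothetical, not part of PrimAITE.

```python
from enum import Enum


class SoftwareState(Enum):
    """Stand-in for primaite.common.enums.SoftwareState (values are illustrative)."""

    GOOD = 1
    PATCHING = 2
    COMPROMISED = 3
    OVERWHELMED = 4


def next_service_state(current: SoftwareState, requested: SoftwareState) -> SoftwareState:
    """Mirror the rule in set_service_state (hypothetical helper)."""
    # A service that is being patched cannot be compromised; any other request is applied.
    if requested == SoftwareState.COMPROMISED and current == SoftwareState.PATCHING:
        return current
    return requested


def next_service_state_if_not_compromised(current: SoftwareState, requested: SoftwareState) -> SoftwareState:
    """Mirror the rule in set_service_state_if_not_compromised (hypothetical helper)."""
    # A compromised service keeps its state; any other service takes the requested state.
    if current == SoftwareState.COMPROMISED:
        return current
    return requested
```

Both sketches leave out the hardware-state guard and the patching counter reset that the real methods also perform.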


@@ -0,0 +1,34 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
import importlib.util
import os
import subprocess
import sys
from logging import Logger
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER: Logger = getLogger(__name__)
def start_jupyter_session() -> None:
"""
Starts a new Jupyter notebook session in the app notebooks directory.
Currently only works on Windows OS.
.. todo:: Figure out how to get this working for Linux and MacOS too.
"""
if importlib.util.find_spec("jupyter") is not None:
jupyter_cmd = "python3 -m jupyter lab"
if sys.platform == "win32":
jupyter_cmd = "jupyter lab"
working_dir = os.getcwd()
os.chdir(PRIMAITE_PATHS.user_notebooks_path)
subprocess.Popen(jupyter_cmd)
os.chdir(working_dir)
else:
# Jupyter is not installed
_LOGGER.error("Cannot start jupyter lab as it is not installed")
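
The docstring notes that the launcher currently only works on Windows. One likely reason is that `subprocess.Popen` receives a single command string without a shell, which POSIX treats as the name of one executable. A possible cross-platform sketch, assuming JupyterLab is installed for the running interpreter, passes an argument list and a `cwd` instead; `build_jupyter_command` and `launch_jupyter` are hypothetical names, not PrimAITE functions.

```python
import subprocess
import sys
from pathlib import Path
from typing import List


def build_jupyter_command() -> List[str]:
    """Build an argument list that runs JupyterLab with the current interpreter."""
    # sys.executable avoids relying on a "python3" or "jupyter" entry on PATH.
    return [sys.executable, "-m", "jupyter", "lab"]


def launch_jupyter(notebooks_dir: Path) -> subprocess.Popen:
    """Start JupyterLab in notebooks_dir without changing this process's working directory."""
    # cwd= replaces the os.chdir() round trip used in the listing above.
    return subprocess.Popen(build_jupyter_command(), cwd=str(notebooks_dir))
```

A caller would pass the notebooks directory, for example `launch_jupyter(PRIMAITE_PATHS.user_notebooks_path)`.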


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Pattern of Life - Represents the actions of users on the network."""


@@ -0,0 +1,264 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE: bool = False
def apply_iers(
network: MultiGraph,
nodes: Dict[str, NodeUnion],
links: Dict[str, Link],
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
) -> None:
"""
Applies IERs to the links (link pattern of life).
Args:
network: The network modelled in the environment
nodes: The nodes within the environment
links: The links within the environment
iers: The IERs to apply to the links
acl: The Access Control List
step: The step number.
"""
if _VERBOSE:
print("Applying IERs")
# Go through each IER and check the conditions for it being applied
# If everything is in place, apply the IER protocol load to the relevant links
for ier_key, ier_value in iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
protocol = ier_value.get_protocol()
port = ier_value.get_port()
load = ier_value.get_load()
source_node_id = ier_value.get_source_node_id()
dest_node_id = ier_value.get_dest_node_id()
# Need to set the running status to false first for all IERs
ier_value.set_is_running(False)
source_valid = True
dest_valid = True
acl_block = False
if step >= start_step and step <= stop_step:
# continue --------------------------
# Get the source and destination node for this link
source_node = nodes[source_node_id]
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if (
source_node.hardware_state == HardwareState.ON
and source_node.software_state != SoftwareState.PATCHING
):
source_valid = True
else:
# IER no longer valid
source_valid = False
elif source_node.node_type == NodeType.ACTUATOR:
# It's an actuator
# TO DO
pass
else:
# It's not a switch or an actuator (so active node)
if (
source_node.hardware_state == HardwareState.ON
and source_node.software_state != SoftwareState.PATCHING
):
if source_node.has_service(protocol):
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
source_valid = True
else:
source_valid = False
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
source_valid = False
else:
# Do nothing - IER no longer valid
source_valid = False
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
dest_valid = True
else:
# IER no longer valid
dest_valid = False
elif dest_node.node_type == NodeType.ACTUATOR:
# It's an actuator
pass
else:
# It's not a switch or an actuator (so active node)
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
if dest_node.has_service(protocol):
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
dest_valid = True
else:
dest_valid = False
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
dest_valid = False
else:
# Do nothing - IER no longer valid
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
"ACL block on source: "
+ source_node.ip_address
+ ", dest: "
+ dest_node.ip_address
+ ", protocol: "
+ protocol
+ ", port: "
+ port
)
else:
if _VERBOSE:
print("No ACL block")
# Check whether both the source and destination are valid, and there's no ACL block
if source_valid and dest_valid and not acl_block:
# Load up the link(s) with the traffic
if _VERBOSE:
print("Source, Dest and ACL valid")
# Get the shortest path (i.e. nodes) between source and destination
path_node_list = shortest_path(network, source_node, dest_node)
path_node_list_length = len(path_node_list)
path_valid = True
# We might have a switch in the path, so check all nodes are operational
for node in path_node_list:
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
path_valid = False
if path_valid:
if _VERBOSE:
print("Applying IER to link(s)")
count = 0
link_capacity_exceeded = False
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
pass
count += 1
# Check whether the link capacity of any link on this path has been exceeded
if not link_capacity_exceeded:
# Now apply the new loads to the links
count = 0
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
# Add the load from this IER
link.add_protocol_load(protocol, load)
count += 1
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
else:
# One of the nodes is not operational
if _VERBOSE:
print("Path not valid - one or more nodes not operational")
pass
else:
if _VERBOSE:
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
else:
# Do nothing - IER no longer valid
pass
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[str, NodeStateInstructionGreen],
step: int,
) -> None:
"""
Applies node pattern of life.
Args:
nodes: The nodes within the environment
node_pol: The node pattern of life to apply
step: The step number.
"""
if _VERBOSE:
print("Applying Node PoL")
for key, node_instruction in node_pol.items():
start_step = node_instruction.get_start_step()
stop_step = node_instruction.get_end_step()
node_id = node_instruction.get_node_id()
node_pol_type = node_instruction.get_node_pol_type()
service_name = node_instruction.get_service_name()
state = node_instruction.get_state()
if step >= start_step and step <= stop_step:
# continue --------------------------
node = nodes[node_id]
if node_pol_type == NodePOLType.OPERATING:
# Change hardware state
node.hardware_state = state
elif node_pol_type == NodePOLType.OS:
# Change OS state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_software_state_if_not_compromised(state)
elif node_pol_type == NodePOLType.SERVICE:
# Change a service state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ServiceNode):
node.set_service_state_if_not_compromised(service_name, state)
else:
# Change the file system status
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_file_system_state_if_not_compromised(state)
else:
# PoL is not valid in this time step
pass
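
`apply_iers` loads links in two passes: it first checks every link on the path for spare bandwidth, and only then adds the load to each link. A simplified sketch of that check-then-commit pattern follows, using plain dictionaries in place of `Link` objects and the networkx graph; `apply_load_along_path` and its parameters are hypothetical.

```python
from typing import Dict, List


def apply_load_along_path(
    link_ids: List[str],
    current_load: Dict[str, int],
    bandwidth: Dict[str, int],
    load: int,
) -> bool:
    """Add load to every link on the path, or to none if any link would overflow."""
    # Pass 1: reject the whole path if any single link would exceed its bandwidth.
    if any(current_load[link_id] + load > bandwidth[link_id] for link_id in link_ids):
        return False
    # Pass 2: every link has room, so commit the load to all of them.
    for link_id in link_ids:
        current_load[link_id] += load
    return True
```

Checking before committing means a path that is blocked on its last link leaves the earlier links untouched.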

src/primaite/pol/ier.py Normal file

@@ -0,0 +1,147 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""
Information Exchange Requirements for APE.
Used to represent an information flow from source to destination.
"""
class IER(object):
"""Information Exchange Requirement class."""
def __init__(
self,
_id: str,
_start_step: int,
_end_step: int,
_load: int,
_protocol: str,
_port: str,
_source_node_id: str,
_dest_node_id: str,
_mission_criticality: int,
_running: bool = False,
) -> None:
"""
Initialise an Information Exchange Requirement.
:param _id: The IER id
:param _start_step: The step when this IER should start
:param _end_step: The step when this IER should end
:param _load: The load this IER should put on a link (bps)
:param _protocol: The protocol of this IER
:param _port: The port this IER runs on
:param _source_node_id: The source node ID
:param _dest_node_id: The destination node ID
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.source_node_id: str = _source_node_id
self.dest_node_id: str = _dest_node_id
self.load: int = _load
self.protocol: str = _protocol
self.port: str = _port
self.mission_criticality: int = _mission_criticality
self.running: bool = _running
def get_id(self) -> str:
"""
Gets IER ID.
Returns:
IER ID
"""
return self.id
def get_start_step(self) -> int:
"""
Gets IER start step.
Returns:
IER start step
"""
return self.start_step
def get_end_step(self) -> int:
"""
Gets IER end step.
Returns:
IER end step
"""
return self.end_step
def get_load(self) -> int:
"""
Gets IER load.
Returns:
IER load
"""
return self.load
def get_protocol(self) -> str:
"""
Gets IER protocol.
Returns:
IER protocol
"""
return self.protocol
def get_port(self) -> str:
"""
Gets IER port.
Returns:
IER port
"""
return self.port
def get_source_node_id(self) -> str:
"""
Gets IER source node ID.
Returns:
IER source node ID
"""
return self.source_node_id
def get_dest_node_id(self) -> str:
"""
Gets IER destination node ID.
Returns:
IER destination node ID
"""
return self.dest_node_id
def get_is_running(self) -> bool:
"""
Informs whether the IER is currently running.
Returns:
True if running
"""
return self.running
def set_is_running(self, _value: bool) -> None:
"""
Sets the running state of the IER.
Args:
_value: running status
"""
self.running = _value
def get_mission_criticality(self) -> int:
"""
Gets the IER mission criticality (used in the reward function).
Returns:
Mission criticality value (0 lowest to 5 highest)
"""
return self.mission_criticality
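
Both `apply_iers` and `apply_red_agent_iers` test whether the current step falls inside an IER's inclusive start/end window. A small sketch of that test as a method follows; `IERSketch` and `is_active` are illustrative names and carry only the fields needed here, so this is not the IER class itself.

```python
from dataclasses import dataclass


@dataclass
class IERSketch:
    """Illustrative stand-in for IER; holds only the fields used below."""

    start_step: int
    end_step: int
    running: bool = False

    def is_active(self, step: int) -> bool:
        """Return True while step lies in the inclusive [start_step, end_step] window."""
        return self.start_step <= step <= self.end_step
```

The chained comparison is equivalent to `step >= start_step and step <= stop_step` as written in the callers.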


@@ -0,0 +1,353 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_LOGGER = getLogger(__name__)
_VERBOSE: bool = False
def apply_red_agent_iers(
network: MultiGraph,
nodes: Dict[str, NodeUnion],
links: Dict[str, Link],
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
) -> None:
"""
Applies IERs to the links (link POL) resulting from red agent attack.
Args:
network: The network modelled in the environment
nodes: The nodes within the environment
links: The links within the environment
iers: The red agent IERs to apply to the links
acl: The Access Control List
step: The step number.
"""
# Go through each IER and check the conditions for it being applied
# If everything is in place, apply the IER protocol load to the relevant links
for ier_key, ier_value in iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
protocol = ier_value.get_protocol()
port = ier_value.get_port()
load = ier_value.get_load()
source_node_id = ier_value.get_source_node_id()
dest_node_id = ier_value.get_dest_node_id()
# Need to set the running status to false first for all IERs
ier_value.set_is_running(False)
source_valid = True
dest_valid = True
acl_block = False
if step >= start_step and step <= stop_step:
# continue --------------------------
# Get the source and destination node for this link
source_node = nodes[source_node_id]
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if source_node.hardware_state == HardwareState.ON:
source_valid = True
else:
# IER no longer valid
source_valid = False
elif source_node.node_type == NodeType.ACTUATOR:
# It's an actuator
# TO DO
pass
else:
# It's not a switch or an actuator (so active node)
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
# to change according to duck typing.
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agent IERs can only be valid if the source service is in a compromised state
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
source_valid = True
else:
source_valid = False
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
source_valid = False
else:
# Do nothing - IER no longer valid
source_valid = False
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if dest_node.hardware_state == HardwareState.ON:
dest_valid = True
else:
# IER no longer valid
dest_valid = False
elif dest_node.node_type == NodeType.ACTUATOR:
# It's an actuator
pass
else:
# It's not a switch or an actuator (so active node)
if dest_node.hardware_state == HardwareState.ON:
if dest_node.has_service(protocol):
# We don't care what state the destination service is in for an IER
dest_valid = True
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
dest_valid = False
else:
# Do nothing - IER no longer valid
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
"ACL block on source: "
+ source_node.ip_address
+ ", dest: "
+ dest_node.ip_address
+ ", protocol: "
+ protocol
+ ", port: "
+ port
)
else:
if _VERBOSE:
print("No ACL block")
# Check whether both the source and destination are valid, and there's no ACL block
if source_valid and dest_valid and not acl_block:
# Load up the link(s) with the traffic
if _VERBOSE:
print("Source, Dest and ACL valid")
# Get the shortest path (i.e. nodes) between source and destination
path_node_list = shortest_path(network, source_node, dest_node)
path_node_list_length = len(path_node_list)
path_valid = True
# We might have a switch in the path, so check all nodes are operational
# We're assuming here that red agents can get past switches that are patching
for node in path_node_list:
if node.hardware_state != HardwareState.ON:
path_valid = False
if path_valid:
if _VERBOSE:
print("Applying IER to link(s)")
count = 0
link_capacity_exceeded = False
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
pass
count += 1
# Check whether the link capacity of any link on this path has been exceeded
if not link_capacity_exceeded:
# Now apply the new loads to the links
count = 0
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
# Add the load from this IER
link.add_protocol_load(protocol, load)
count += 1
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
if _VERBOSE:
print("Red IER was allowed to run in step " + str(step))
else:
# One of the nodes is not operational
if _VERBOSE:
print("Path not valid - one or more nodes not operational")
pass
else:
if _VERBOSE:
print("Red IER was NOT allowed to run in step " + str(step))
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
else:
# Do nothing - IER no longer valid
pass
pass
def apply_red_agent_node_pol(
nodes: Dict[str, NodeUnion],
iers: Dict[str, IER],
node_pol: Dict[str, NodeStateInstructionRed],
step: int,
) -> None:
"""
Applies the red agent node pattern of life.
Args:
nodes: The nodes within the environment
iers: The red agent IERs
node_pol: The red agent node pattern of life to apply
step: The step number.
"""
if _VERBOSE:
print("Applying Node Red Agent PoL")
for key, node_instruction in node_pol.items():
start_step = node_instruction.get_start_step()
stop_step = node_instruction.get_end_step()
target_node_id = node_instruction.get_target_node_id()
initiator = node_instruction.get_initiator()
pol_type = node_instruction.get_pol_type()
service_name = node_instruction.get_service_name()
state = node_instruction.get_state()
source_node_id = node_instruction.get_source_node_id()
source_node_service_name = node_instruction.get_source_node_service()
source_node_service_state_value = node_instruction.get_source_node_service_state()
passed_checks = False
if step >= start_step and step <= stop_step:
# continue --------------------------
target_node: NodeUnion = nodes[target_node_id]
# check if the initiator type is a str, and if so, cast it as
# NodePOLInitiator
if isinstance(initiator, str):
initiator = NodePOLInitiator[initiator]
# Base the action taken on the initiator type
if initiator == NodePOLInitiator.DIRECT:
# No conditions required, just apply the change
passed_checks = True
elif initiator == NodePOLInitiator.IER:
# Need to check there is a red IER incoming
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
elif initiator == NodePOLInitiator.SERVICE:
# Need to check the condition of a service on another node
source_node = nodes[source_node_id]
if source_node.has_service(source_node_service_name):
if (
source_node.get_service_state(source_node_service_name)
== SoftwareState[source_node_service_state_value]
):
passed_checks = True
else:
# Do nothing, no matching state value
pass
else:
# Do nothing, service not on this node
pass
else:
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
# Only apply the PoL if the checks have passed (based on the initiator type)
if passed_checks:
# Apply the change
if pol_type == NodePOLType.OPERATING:
# Change hardware state
target_node.hardware_state = state
elif pol_type == NodePOLType.OS:
# Change OS state
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.software_state = state
elif pol_type == NodePOLType.SERVICE:
# Change a service state
if isinstance(target_node, ServiceNode):
target_node.set_service_state(service_name, state)
else:
# Change the file system status
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass
def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool:
"""Checks if the RED IER is incoming.
:param node: Destination node of the IER
:type node: NodeUnion
:param iers: Dictionary of IERs
:type iers: Dict[str,IER]
:param node_pol_type: Type of Pattern-Of-Life
:type node_pol_type: NodePOLType
:return: Whether the RED IER is incoming.
:rtype: bool
"""
node_id = node.node_id
for ier_key, ier_value in iers.items():
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
if (
node_pol_type == NodePOLType.OPERATING
or node_pol_type == NodePOLType.OS
or node_pol_type == NodePOLType.FILE
):
# It's looking to change hardware state, file system or SoftwareState, so valid
return True
elif node_pol_type == NodePOLType.SERVICE:
# Check if the service is present on the node and running
ier_protocol = ier_value.get_protocol()
if isinstance(node, ServiceNode):
if node.has_service(ier_protocol):
if node.service_running(ier_protocol):
# Matching service is present and running, so valid
return True
else:
# Service is present, but not running
return False
else:
# Service is not present
return False
else:
# Not a service node
return False
else:
# Shouldn't get here - instruction type is undefined
return False
else:
# The IER destination is not this node, or the IER is not running
return False
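
As listed, every branch inside the loop of `is_red_ier_incoming` returns, so only the first IER in the dictionary is ever examined, and an empty dictionary falls through to an implicit `None`. If the intent is "any running red IER addressed to this node" (an assumption, since the source does not say), the matching step could be sketched as below. `RunningIER` and `any_red_ier_incoming` are hypothetical, and the sketch deliberately omits the per-type service checks.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class RunningIER:
    """Illustrative stand-in carrying only the two IER fields used for matching."""

    dest_node_id: str
    running: bool


def any_red_ier_incoming(node_id: str, iers: Iterable[RunningIER]) -> bool:
    """Return True if at least one running IER targets node_id."""
    # Idle or differently addressed IERs are skipped rather than ending the search.
    return any(ier.running and ier.dest_node_id == node_id for ier in iers)
```

With this shape an empty collection yields `False` rather than `None`.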


@@ -0,0 +1,225 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Final, Optional, Tuple, Union

from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.utils.session_metadata_parser import parse_session_metadata
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict

_LOGGER = getLogger(__name__)


class PrimaiteSession:
    """
    The PrimaiteSession class.

    Provides a single learning and evaluation entry point for all training and lay down configurations.
    """

    def __init__(
        self,
        training_config_path: Optional[Union[str, Path]] = "",
        lay_down_config_path: Optional[Union[str, Path]] = "",
        session_path: Optional[Union[str, Path]] = None,
        legacy_training_config: bool = False,
        legacy_lay_down_config: bool = False,
    ) -> None:
        """
        The PrimaiteSession constructor.

        :param training_config_path: YAML file containing configurable items defined in
            `primaite.config.training_config.TrainingConfig`
        :type training_config_path: Union[Path, str]
        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
        :type lay_down_config_path: Union[Path, str]
        :param session_path: directory path of the session to load
        :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
            otherwise False.
        :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
            otherwise False.
        """
        self._agent_session: AgentSessionABC = None  # noqa
        self.session_path: Path = session_path  # noqa
        self.timestamp_str: str = None  # noqa
        self.learning_path: Path = None  # noqa
        self.evaluation_path: Path = None  # noqa
        self.legacy_training_config = legacy_training_config
        self.legacy_lay_down_config = legacy_lay_down_config
        # Default to False so the attribute exists even when no session is loaded
        self.is_load_session: bool = False

        # check if session path is provided
        if session_path is not None:
            # set load_session to true
            self.is_load_session = True
            if not isinstance(session_path, Path):
                session_path = Path(session_path)

            # if a session path is provided, load it
            if not session_path.exists():
                raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")

            md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)

        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Union[Path, str]] = training_config_path
        self._training_config: Final[TrainingConfig] = training_config.load(
            self._training_config_path, legacy_training_config
        )

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
        self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)  # noqa

    def setup(self) -> None:
        """Performs the session setup."""
        if self._training_config.agent_framework == AgentFramework.CUSTOM:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
            if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.HARDCODED}")
                if self._training_config.action_type == ActionType.NODE:
                    # Deterministic Hardcoded Agent with Node Action Space
                    self._agent_session = HardCodedNodeAgent(
                        self._training_config_path, self._lay_down_config_path, self.session_path
                    )

                elif self._training_config.action_type == ActionType.ACL:
                    # Deterministic Hardcoded Agent with ACL Action Space
                    self._agent_session = HardCodedACLAgent(
                        self._training_config_path, self._lay_down_config_path, self.session_path
                    )

                elif self._training_config.action_type == ActionType.ANY:
                    # Deterministic Hardcoded Agent with ANY Action Space
                    raise NotImplementedError

                else:
                    # Invalid AgentIdentifier ActionType combo
                    raise ValueError

            elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.DO_NOTHING}")
                if self._training_config.action_type == ActionType.NODE:
                    # Do-nothing Agent with Node Action Space
                    self._agent_session = DoNothingNodeAgent(
                        self._training_config_path, self._lay_down_config_path, self.session_path
                    )

                elif self._training_config.action_type == ActionType.ACL:
                    # Do-nothing Agent with ACL Action Space
                    self._agent_session = DoNothingACLAgent(
                        self._training_config_path, self._lay_down_config_path, self.session_path
                    )

                elif self._training_config.action_type == ActionType.ANY:
                    # Do-nothing Agent with ANY Action Space
                    raise NotImplementedError

                else:
                    # Invalid AgentIdentifier ActionType combo
                    raise ValueError

            elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.RANDOM}")
                self._agent_session = RandomAgent(
                    self._training_config_path, self._lay_down_config_path, self.session_path
                )
            elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
                _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.DUMMY}")
                self._agent_session = DummyAgent(
                    self._training_config_path, self._lay_down_config_path, self.session_path
                )

            else:
                # Invalid AgentFramework AgentIdentifier combo
                raise ValueError

        elif self._training_config.agent_framework == AgentFramework.SB3:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
            # Stable Baselines3 Agent
            self._agent_session = SB3Agent(
                self._training_config_path,
                self._lay_down_config_path,
                self.session_path,
                self.legacy_training_config,
                self.legacy_lay_down_config,
            )

        elif self._training_config.agent_framework == AgentFramework.RLLIB:
            _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
            # Ray RLlib Agent
            self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)

        else:
            # Invalid AgentFramework
            raise ValueError

        self.session_path: Path = self._agent_session.session_path
        self.timestamp_str: str = self._agent_session.timestamp_str
        self.learning_path: Path = self._agent_session.learning_path
        self.evaluation_path: Path = self._agent_session.evaluation_path

    def learn(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Train the agent.

        :param kwargs: Any agent-framework specific key word args.
        """
        # Skip training for evaluation-only sessions
        if self._training_config.session_type != SessionType.EVAL:
            self._agent_session.learn(**kwargs)

    def evaluate(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Evaluate the agent.

        :param kwargs: Any agent-framework specific key word args.
        """
        # Skip evaluation for training-only sessions
        if self._training_config.session_type != SessionType.TRAIN:
            self._agent_session.evaluate(**kwargs)

    def close(self) -> None:
        """Closes the agent."""
        self._agent_session.close()

    def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
        """Get the learn av reward per episode from file."""
        csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
        return av_rewards_dict(self.learning_path / csv_file)

    def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
        """Get the eval av reward per episode from file."""
        csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
        return av_rewards_dict(self.evaluation_path / csv_file)

    def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
        """Get the learn all transactions from file."""
        csv_file = f"all_transactions_{self.timestamp_str}.csv"
        return all_transactions_dict(self.learning_path / csv_file)

    def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
        """Get the eval all transactions from file."""
        csv_file = f"all_transactions_{self.timestamp_str}.csv"
        return all_transactions_dict(self.evaluation_path / csv_file)

    def metadata_file_as_dict(self) -> Dict[str, Any]:
        """Read the session_metadata.json file and return as a dict."""
        with open(self.session_path / "session_metadata.json", "r") as file:
            return json.load(file)
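
The `setup` method above picks an agent class from three config values: agent framework, agent identifier and action type. One way to make that mapping easier to audit is a lookup table. The following is only a sketch under stated assumptions, not the project's implementation: the enums are local stand-ins whose member names come from the diff (their values, and the extra `PPO` member, are assumed for illustration), and agent classes are represented by their names as strings so the example runs without PrimAITE installed.

```python
from enum import Enum
from typing import Dict, Optional, Tuple


class AgentFramework(Enum):
    CUSTOM = "CUSTOM"
    SB3 = "SB3"
    RLLIB = "RLLIB"


class AgentIdentifier(Enum):
    HARDCODED = "HARDCODED"
    DO_NOTHING = "DO_NOTHING"
    RANDOM = "RANDOM"
    DUMMY = "DUMMY"
    PPO = "PPO"  # assumed member, used here only to show the error path


class ActionType(Enum):
    NODE = "NODE"
    ACL = "ACL"
    ANY = "ANY"


Key = Tuple[AgentFramework, Optional[AgentIdentifier], Optional[ActionType]]

# None means the field is ignored for that framework or identifier.
AGENT_TABLE: Dict[Key, str] = {
    (AgentFramework.CUSTOM, AgentIdentifier.HARDCODED, ActionType.NODE): "HardCodedNodeAgent",
    (AgentFramework.CUSTOM, AgentIdentifier.HARDCODED, ActionType.ACL): "HardCodedACLAgent",
    (AgentFramework.CUSTOM, AgentIdentifier.DO_NOTHING, ActionType.NODE): "DoNothingNodeAgent",
    (AgentFramework.CUSTOM, AgentIdentifier.DO_NOTHING, ActionType.ACL): "DoNothingACLAgent",
    (AgentFramework.CUSTOM, AgentIdentifier.RANDOM, None): "RandomAgent",
    (AgentFramework.CUSTOM, AgentIdentifier.DUMMY, None): "DummyAgent",
    (AgentFramework.SB3, None, None): "SB3Agent",
    (AgentFramework.RLLIB, None, None): "RLlibAgent",
}


def select_agent(framework: AgentFramework, identifier: AgentIdentifier, action_type: ActionType) -> str:
    """Return the agent class name for a config combination, mirroring the branches in setup()."""
    if framework is not AgentFramework.CUSTOM:
        # SB3 and RLlib ignore identifier and action type at this level
        key: Key = (framework, None, None)
    elif identifier in (AgentIdentifier.RANDOM, AgentIdentifier.DUMMY):
        # Random and dummy agents ignore the action type
        key = (framework, identifier, None)
    elif identifier in (AgentIdentifier.HARDCODED, AgentIdentifier.DO_NOTHING) and action_type is ActionType.ANY:
        # As in the diff, these agents have no ANY action space variant
        raise NotImplementedError(f"{identifier.name} agent does not support the ANY action type")
    else:
        key = (framework, identifier, action_type)

    if key not in AGENT_TABLE:
        raise ValueError(f"Unsupported combination: {framework.name}, {identifier.name}, {action_type.name}")
    return AGENT_TABLE[key]
```

A table like this also gives the `ValueError` a natural place to carry a message naming the rejected combination, which the bare `raise ValueError` statements in the diff do not.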


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utilities to prepare the user's data folders."""


@@ -0,0 +1,22 @@
# The main PrimAITE application config file

# Logging
logging:
  log_level: INFO
  logger_format:
    DEBUG: '%(asctime)s: %(message)s'
    INFO: '%(asctime)s: %(message)s'
    WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
    ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
    CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'

# Session
session:
  outputs:
    plots:
      size:
        auto_size: false
        width: 1500
        height: 900
      template: plotly_white
      range_slider: false
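
The `logger_format` block above assigns a different format string to each log level. How PrimAITE consumes it is not shown in this diff; the sketch below is one possible approach using only the standard `logging` module. The `LevelFormatter` class is a hypothetical name, and the format strings are copied into a dict so the example does not depend on a YAML parser.

```python
import logging
from typing import Dict

# Format strings copied from the config above
LOGGER_FORMAT: Dict[str, str] = {
    "DEBUG": "%(asctime)s: %(message)s",
    "INFO": "%(asctime)s: %(message)s",
    "WARNING": "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
    "ERROR": "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
    "CRITICAL": "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s",
}


class LevelFormatter(logging.Formatter):
    """Hypothetical formatter that picks a format string by record level."""

    def __init__(self, formats: Dict[str, str]) -> None:
        super().__init__()
        # Map numeric levels (e.g. logging.INFO) to a dedicated Formatter
        self._formatters = {
            getattr(logging, level_name): logging.Formatter(fmt) for level_name, fmt in formats.items()
        }

    def format(self, record: logging.LogRecord) -> str:
        formatter = self._formatters.get(record.levelno)
        if formatter is None:
            # Fall back to the default format for unlisted levels
            return super().format(record)
        return formatter.format(record)


def build_handler() -> logging.Handler:
    """Create a stream handler that applies the per-level formats."""
    handler = logging.StreamHandler()
    handler.setFormatter(LevelFormatter(LOGGER_FORMAT))
    return handler
```

With this arrangement, DEBUG and INFO records stay terse while WARNING and above also carry the level name, logger name and line number.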


@@ -0,0 +1,14 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from primaite import getLogger

_LOGGER = getLogger(__name__)


def run() -> None:
    """Perform the full clean-up."""
    pass


if __name__ == "__main__":
    run()


@@ -0,0 +1,35 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import filecmp
import os
import shutil
from logging import Logger
from pathlib import Path

import pkg_resources

from primaite import getLogger, PRIMAITE_PATHS

_LOGGER: Logger = getLogger(__name__)


def run(overwrite_existing: bool = True) -> None:
    """
    Resets the demo Jupyter notebooks in the user's app notebooks directory.

    :param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
    """
    notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
    for subdir, dirs, files in os.walk(notebooks_package_data_root):
        for file in files:
            fp = os.path.join(subdir, file)
            path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
            target_fp = PRIMAITE_PATHS.user_notebooks_path / Path(*path_split)
            target_fp.parent.mkdir(exist_ok=True, parents=True)
            # Always copy files that are missing from the user directory
            copy_file = not target_fp.is_file()

            if overwrite_existing and not copy_file:
                # Replace files that differ from the packaged copy, ignoring Jupyter checkpoints
                copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp))

            if copy_file:
                shutil.copy2(fp, target_fp)
                _LOGGER.info(f"Reset example notebook: {target_fp}")


@@ -0,0 +1,35 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import filecmp
import os
import shutil
from pathlib import Path

import pkg_resources

from primaite import getLogger, PRIMAITE_PATHS

_LOGGER = getLogger(__name__)


def run(overwrite_existing: bool = True) -> None:
    """
    Resets the example config files in the user's app config directory.

    :param overwrite_existing: A bool to toggle replacing existing edited config on or off.
    """
    configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")

    for subdir, dirs, files in os.walk(configs_package_data_root):
        for file in files:
            fp = os.path.join(subdir, file)
            path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
            target_fp = PRIMAITE_PATHS.user_config_path / "example_config" / Path(*path_split)
            target_fp.parent.mkdir(exist_ok=True, parents=True)
            # Always copy files that are missing from the user directory
            copy_file = not target_fp.is_file()

            if overwrite_existing and not copy_file:
                # Replace files that differ from the packaged copy
                copy_file = not filecmp.cmp(fp, target_fp)

            if copy_file:
                shutil.copy2(fp, target_fp)
                _LOGGER.info(f"Reset example config: {target_fp}")
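
Both reset scripts above follow the same rule: copy a packaged file when the target is missing, and, if `overwrite_existing` is set, also when the target differs from the packaged copy. A minimal standalone sketch of that rule follows. The `sync_tree` name is hypothetical, and it departs from the diff in two ways that are choices of this example: it walks the tree with `Path.rglob` instead of `os.walk`, and it passes `shallow=False` to `filecmp.cmp` so that file contents are always compared, not just the `os.stat` signatures.

```python
import filecmp
import shutil
from pathlib import Path
from typing import List


def sync_tree(source_root: Path, target_root: Path, overwrite_existing: bool = True) -> List[Path]:
    """Copy missing (and optionally changed) files from source_root into target_root.

    Returns the list of target paths that were written.
    """
    copied: List[Path] = []
    for source in sorted(p for p in source_root.rglob("*") if p.is_file()):
        target = target_root / source.relative_to(source_root)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Files missing from the target are always copied
        needs_copy = not target.is_file()

        if overwrite_existing and not needs_copy:
            # shallow=False forces a byte-by-byte content comparison
            needs_copy = not filecmp.cmp(source, target, shallow=False)

        if needs_copy:
            shutil.copy2(source, target)
            copied.append(target)
    return copied
```

Returning the list of written paths makes the behaviour easy to check in a test, where the diff's scripts report each copy through the logger instead.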


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Record data of the system's state and agent's observations and actions."""


@@ -0,0 +1,102 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Transaction class."""
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

from primaite.common.enums import AgentIdentifier

if TYPE_CHECKING:
    import numpy as np
    from gym import spaces


class Transaction(object):
    """Transaction class."""

    def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
        """
        Transaction constructor.

        :param agent_identifier: An identifier for the agent in use
        :param episode_number: The episode number
        :param step_number: The step number
        """
        self.timestamp: datetime = datetime.now()
        "The datetime of the transaction"
        self.agent_identifier: AgentIdentifier = agent_identifier
        "The agent identifier"
        self.episode_number: int = episode_number
        "The episode number"
        self.step_number: int = step_number
        "The step number"
        self.obs_space: "spaces.Space" = None
        "The observation space (pre)"
        self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
        "The observation space before any actions are taken"
        self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
        "The observation space after any actions are taken"
        self.reward: Optional[float] = None
        "The reward value"
        self.action_space: Optional[int] = None
        "The action space invoked by the agent"
        self.obs_space_description: Optional[List[str]] = None
        "The env observation space description"

    def as_csv_data(self) -> Tuple[List, List]:
        """
        Converts the Transaction to a csv data row and provides a header.

        :return: A tuple consisting of (header, data).
        """
        if isinstance(self.action_space, int):
            action_length = self.action_space
        else:
            action_length = self.action_space.size

        # Create the action space headers array
        action_header = []
        for x in range(action_length):
            action_header.append("AS_" + str(x))

        # Build the CSV header and the matching data row
        header = ["Timestamp", "Episode", "Step", "Reward"]
        header = header + action_header + self.obs_space_description

        row = [
            str(self.timestamp),
            str(self.episode_number),
            str(self.step_number),
            str(self.reward),
        ]
        row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist()
        return header, row


def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
    """
    Turns action space into a string array so it can be saved to csv.

    :param action_space: The action space
    :return: The action space as an array of strings
    """
    if isinstance(action_space, list):
        return [str(i) for i in action_space]
    else:
        return [str(action_space)]


def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
    """
    Turns observation space into a string array so it can be saved to csv.

    :param obs_space: The observation space
    :param obs_assets: The number of assets (i.e. nodes or links) in the observation space
    :param obs_features: The number of features associated with the asset
    :return: The observation space as an array of strings
    """
    return_array = []
    for x in range(obs_assets):
        for y in range(obs_features):
            return_array.append(str(obs_space[x][y]))
    return return_array
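
`as_csv_data` above assembles a header and a data row from the step metadata, the action and the observation. The sketch below illustrates the same flattening with plain Python lists so it runs without NumPy or gym. It is an illustration under assumptions, not the project's code: `flatten_action` and `build_csv_row` are hypothetical names, the timestamp column is omitted, and the action header is derived from the flattened action cells (a choice of this example, made so the header and row always have the same length).

```python
from typing import List, Sequence, Tuple, Union


def flatten_action(action: Union[int, Sequence[int]]) -> List[str]:
    """Turn a scalar or multi-part action into a list of CSV cells."""
    if isinstance(action, int):
        return [str(action)]
    return [str(part) for part in action]


def build_csv_row(
    episode: int,
    step: int,
    reward: float,
    action: Union[int, Sequence[int]],
    observation: Sequence[Sequence[int]],
    obs_description: Sequence[str],
) -> Tuple[List[str], List[str]]:
    """Return (header, row) for one transaction, with one column per cell."""
    action_cells = flatten_action(action)
    # Flatten the asset-by-feature observation grid row by row
    obs_cells = [str(value) for asset in observation for value in asset]
    if len(obs_cells) != len(obs_description):
        raise ValueError("Observation description does not match the observation size")

    header = ["Episode", "Step", "Reward"]
    header += [f"AS_{i}" for i in range(len(action_cells))]
    header += list(obs_description)

    row = [str(episode), str(step), str(reward)] + action_cells + obs_cells
    return header, row
```

The resulting pair can be passed to `csv.writer(...).writerow` one after the other, header first.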


@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utilities for PrimAITE."""

Some files were not shown because too many files have changed in this diff.