Compare commits
492 Commits
v3.0.0
...
main_backu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fb88c94e8 | ||
|
|
44de1531a5 | ||
|
|
5e526ba81e | ||
|
|
506b0836ea | ||
|
|
c2992c3348 | ||
|
|
aaf1b16912 | ||
|
|
6bf348e5c3 | ||
|
|
4ed31a52e5 | ||
|
|
efb230a5fc | ||
|
|
da8f91c78c | ||
|
|
3fae05d971 | ||
|
|
5893ea8db4 | ||
|
|
10ce2923c7 | ||
|
|
cd7ba9986c | ||
|
|
1084be914b | ||
|
|
f01825b180 | ||
|
|
5bf9f7f4ea | ||
|
|
f71b3480f0 | ||
|
|
87741ba994 | ||
|
|
0a1d17c9cc | ||
|
|
e647b35f6f | ||
|
|
b15be9796d | ||
|
|
b40fb09c1f | ||
|
|
99ff8ca4e1 | ||
|
|
ebb901c2b2 | ||
|
|
094c2380cd | ||
|
|
1b8af0d862 | ||
|
|
0b93ca41ab | ||
|
|
69b0ea4572 | ||
|
|
0fc596e06b | ||
|
|
a5fa613bea | ||
|
|
46699880ce | ||
|
|
22f72139e3 | ||
|
|
eec1c25989 | ||
|
|
514f239cc6 | ||
|
|
21931e991c | ||
|
|
0c601d0383 | ||
|
|
8e3a0a0afa | ||
|
|
7fe5df7fc4 | ||
|
|
8bf1440f9b | ||
|
|
e9b52a69b7 | ||
|
|
4f16105b67 | ||
|
|
7e0f55cdb8 | ||
|
|
31955c0c84 | ||
|
|
ef6585a298 | ||
|
|
5d9dd7a2d9 | ||
|
|
a39541623a | ||
|
|
63297ef0ed | ||
|
|
4527b38aa6 | ||
|
|
050ca68907 | ||
|
|
cda0a28c03 | ||
|
|
b8ca5f1dca | ||
|
|
196d8855c3 | ||
|
|
e1a396981a | ||
|
|
10c8604159 | ||
|
|
722fe97c84 | ||
|
|
e62bee3052 | ||
|
|
7999eb56a5 | ||
|
|
21598fd792 | ||
|
|
df52236a7d | ||
|
|
470f52f35e | ||
|
|
5475155686 | ||
|
|
0a6078df65 | ||
|
|
fabbde9641 | ||
|
|
69e7b23d2c | ||
|
|
ae8afbdcdc | ||
|
|
6ce816f2e1 | ||
|
|
f840e924e3 | ||
|
|
3bbc7b8615 | ||
|
|
f6a9063484 | ||
|
|
994546046e | ||
|
|
90620e0c64 | ||
|
|
3731b2ba13 | ||
|
|
dd9613853b | ||
|
|
f22681d6b4 | ||
|
|
e9fc9a0d1a | ||
|
|
1e24ce7b9a | ||
|
|
afbe2e1400 | ||
|
|
f8959a65e9 | ||
|
|
ba6f8f054b | ||
|
|
52a7185583 | ||
|
|
d730a206d8 | ||
|
|
713aa279ec | ||
|
|
e070b247b1 | ||
|
|
a10a1d9267 | ||
|
|
9e3285350a | ||
|
|
c5f612889e | ||
|
|
9c6ee73b9e | ||
|
|
a2ef4328dd | ||
|
|
393505b98b | ||
|
|
e198c17ac0 | ||
|
|
a7a5fb8598 | ||
|
|
3d0e50823a | ||
|
|
15f37c938f | ||
|
|
9520cfea24 | ||
|
|
81295a4fc4 | ||
|
|
a2f43b5abc | ||
|
|
95b6211781 | ||
|
|
3aab6a3738 | ||
|
|
1721f2eb84 | ||
|
|
2d1a1e6db7 | ||
|
|
35af1e9d1e | ||
|
|
e1ac628793 | ||
|
|
bfce2f9a7b | ||
|
|
257be9532f | ||
|
|
57157db08c | ||
|
|
39b30460cd | ||
|
|
dd21f9440f | ||
|
|
a432822bcb | ||
|
|
ec938ce761 | ||
|
|
007a0c4b98 | ||
|
|
2526427f2f | ||
|
|
cc09fe9079 | ||
|
|
78d7f39342 | ||
|
|
da20c0e9e6 | ||
|
|
360eb38c2b | ||
|
|
6b76214eb2 | ||
|
|
75c91b9eb9 | ||
|
|
432da5ca90 | ||
|
|
9a0b14b111 | ||
|
|
707d8f6189 | ||
|
|
50697c6f75 | ||
|
|
9df8d132fc | ||
|
|
2bb71623fa | ||
|
|
ea7c1519fe | ||
|
|
8aa71c3ff8 | ||
|
|
7c2ff55da2 | ||
|
|
1b6244d13f | ||
|
|
661c865108 | ||
|
|
eb75d15722 | ||
|
|
31fedb945e | ||
|
|
4e53564670 | ||
|
|
8e2f105d57 | ||
|
|
f9c7cafe87 | ||
|
|
e743b2380c | ||
|
|
c2931bde6c | ||
|
|
436448beed | ||
|
|
7b929109dc | ||
|
|
118b05ede0 | ||
|
|
9650669c83 | ||
|
|
558223e8b6 | ||
|
|
77f717c649 | ||
|
|
738e5b5dca | ||
|
|
606354614a | ||
|
|
36e48dc8e9 | ||
|
|
0ab4dab72a | ||
|
|
f8cb18c654 | ||
|
|
fd2ab39edf | ||
|
|
f5e1ef7491 | ||
|
|
f4a70394e0 | ||
|
|
c61770825a | ||
|
|
a19fbd1e98 | ||
|
|
85c360548b | ||
|
|
06c20f6984 | ||
|
|
96b48aad79 | ||
|
|
f817efdc69 | ||
|
|
c7547f715e | ||
|
|
d4469f5226 | ||
|
|
3c20764096 | ||
|
|
11defda955 | ||
|
|
5b3663c3cf | ||
|
|
baa14b6cd7 | ||
|
|
79724d6884 | ||
|
|
30d8478a78 | ||
|
|
0ec2f79ac3 | ||
|
|
0c63d197e5 | ||
|
|
585d35338f | ||
|
|
f3750032be | ||
|
|
350b3db3f6 | ||
|
|
6a888d2efe | ||
|
|
5f6bc32b98 | ||
|
|
6b59ce960d | ||
|
|
9e936513d5 | ||
|
|
dc26863216 | ||
|
|
56fd9c4d0a | ||
|
|
1633900ce7 | ||
|
|
6c7ec62166 | ||
|
|
a07ce00852 | ||
|
|
dcf5bfddfa | ||
|
|
a303e9096a | ||
|
|
81a8058836 | ||
|
|
c641f67914 | ||
|
|
7f64d06ad4 | ||
|
|
c8191e60ba | ||
|
|
d555584e90 | ||
|
|
548ecf8e08 | ||
|
|
d8cfbc1042 | ||
|
|
831469d01c | ||
|
|
19a9cef130 | ||
|
|
5939fda2ba | ||
|
|
ecc06a5db0 | ||
|
|
30bcdba429 | ||
|
|
563ff72fd6 | ||
|
|
ca737e080f | ||
|
|
921dc934c2 | ||
|
|
5ec8d3c8c1 | ||
|
|
43a4f93626 | ||
|
|
3c9b8a272a | ||
|
|
e3ad1470df | ||
|
|
bd6f9fc309 | ||
|
|
47d7e9f3f6 | ||
|
|
0145532103 | ||
|
|
91287f8666 | ||
|
|
605a5b4cd6 | ||
|
|
17894376c6 | ||
|
|
9d49406df6 | ||
|
|
23adc740cd | ||
|
|
41fab6562e | ||
|
|
752a611b89 | ||
|
|
677d12b550 | ||
|
|
40381833d3 | ||
|
|
35b481a2f3 | ||
|
|
d49f73f139 | ||
|
|
d7bf678b1f | ||
|
|
e03c29b921 | ||
|
|
bbb305d561 | ||
|
|
4ef7831bfa | ||
|
|
7e0eee5d73 | ||
|
|
f4b98542b6 | ||
|
|
036e0fe342 | ||
|
|
04e52453b1 | ||
|
|
207601b81f | ||
|
|
3a75ed8ccc | ||
|
|
86725064ec | ||
|
|
2a08d3a2a5 | ||
|
|
82a5122276 | ||
|
|
4c03aaee24 | ||
|
|
1ade92f55c | ||
|
|
c9f4741655 | ||
|
|
82d7c168fe | ||
|
|
159d47fd6c | ||
|
|
46b44f9e23 | ||
|
|
3b91a99070 | ||
|
|
c5d7d55747 | ||
|
|
99f1f7cfc1 | ||
|
|
3438ce7e09 | ||
|
|
4371ca13fc | ||
|
|
f651937759 | ||
|
|
e174db5d9e | ||
|
|
87bdaa1ec3 | ||
|
|
c38dda34b9 | ||
|
|
8faf9d70a0 | ||
|
|
b426d5802e | ||
|
|
5c167293e3 | ||
|
|
0ae7158859 | ||
|
|
713225b432 | ||
|
|
7482aead76 | ||
|
|
f62b2aef1c | ||
|
|
171b5cb58e | ||
|
|
b3d4eb4ec0 | ||
|
|
075b11aeca | ||
|
|
f121b0e21c | ||
|
|
940f37bfc6 | ||
|
|
38a3666e8e | ||
|
|
ea01e2209b | ||
|
|
3ced1a1913 | ||
|
|
cda9819e72 | ||
|
|
eac79e0941 | ||
|
|
3f440c0a28 | ||
|
|
9001510fe7 | ||
|
|
0756e61e5d | ||
|
|
7bdcee5c46 | ||
|
|
d41e2ad590 | ||
|
|
5e270c7673 | ||
|
|
3de6208915 | ||
|
|
3abe39aa10 | ||
|
|
410afc1d40 | ||
|
|
e199dc52c0 | ||
|
|
34b294f89a | ||
|
|
410d5abe12 | ||
|
|
820f436f8e | ||
|
|
7816e94f83 | ||
|
|
dffa612ec8 | ||
|
|
4b5cf12aa3 | ||
|
|
7ddedfcc57 | ||
|
|
a883e45bbf | ||
|
|
8ab936fcdc | ||
|
|
d2764d53cc | ||
|
|
12c18adeb1 | ||
|
|
178bd4dc7f | ||
|
|
f47dd8bf61 | ||
|
|
dc4c2c8854 | ||
|
|
8101f49a21 | ||
|
|
63a4c1119b | ||
|
|
94ca28a85f | ||
|
|
cb9d40579f | ||
|
|
0943e9511b | ||
|
|
c3ec33e4df | ||
|
|
123ec8343c | ||
|
|
c38c13b829 | ||
|
|
6c4a538b41 | ||
|
|
ae56827bae | ||
|
|
4299170ce4 | ||
|
|
4f0f542570 | ||
|
|
ee94993344 | ||
|
|
41aed12f27 | ||
|
|
ccad245e6f | ||
|
|
16534237e0 | ||
|
|
27ca53878a | ||
|
|
605ff98a24 | ||
|
|
975ebd6de2 | ||
|
|
203cc98494 | ||
|
|
32d5889b11 | ||
|
|
2a8d28cba6 | ||
|
|
3e691b4f46 | ||
|
|
d5402cdce8 | ||
|
|
c3c4512544 | ||
|
|
73015802ec | ||
|
|
c77fde3dd3 | ||
|
|
f61d50a96f | ||
|
|
a2e02c3cfd | ||
|
|
7f912df383 | ||
|
|
1d3778f400 | ||
|
|
7482192046 | ||
|
|
9666b92caa | ||
|
|
498e6a7ac1 | ||
|
|
02f982afa8 | ||
|
|
b8a4ede83f | ||
|
|
8a1c0b2db7 | ||
|
|
cfeb1c6530 | ||
|
|
746f878747 | ||
|
|
a8c27ec975 | ||
|
|
cffdcdc0d2 | ||
|
|
8f2fd77634 | ||
|
|
301e8b6983 | ||
|
|
cf2f9788ec | ||
|
|
3adb02118c | ||
|
|
185dbb7f02 | ||
|
|
be7d0e1745 | ||
|
|
0bff2d2f36 | ||
|
|
79ecb8e0b9 | ||
|
|
09412cb43d | ||
|
|
ebc0a28460 | ||
|
|
ef4d2c6cdd | ||
|
|
e2d6abf833 | ||
|
|
feead2cd44 | ||
|
|
fb50b8becf | ||
|
|
e0f3d61f65 | ||
|
|
7f1c4ce036 | ||
|
|
5a6fdf58d4 | ||
|
|
a2cc4233b5 | ||
|
|
df42a791c9 | ||
|
|
1a5bd3af48 | ||
|
|
db67a829d5 | ||
|
|
0ab4520904 | ||
|
|
03ae4884e0 | ||
|
|
23bafde457 | ||
|
|
c2c396052f | ||
|
|
6849939265 | ||
|
|
c6a947fbaf | ||
|
|
5b59642695 | ||
|
|
fe102dff6f | ||
|
|
cf64990cff | ||
|
|
eb3368edd6 | ||
|
|
cdd7183d85 | ||
|
|
9b0e24c27b | ||
|
|
785409e12a | ||
|
|
a08ec8844a | ||
|
|
eac17b6e16 | ||
|
|
8f86bda4d2 | ||
|
|
3c8a8188fb | ||
|
|
29d1566789 | ||
|
|
c5175c500e | ||
|
|
605737cd5f | ||
|
|
747ea9d0c6 | ||
|
|
f5e195604f | ||
|
|
29ba64462a | ||
|
|
afc133cbc5 | ||
|
|
0dbd89e5cb | ||
|
|
af4e71db9b | ||
|
|
7382ed26b3 | ||
|
|
fd3b304373 | ||
|
|
9b4ed1199b | ||
|
|
647ba2fcc1 | ||
|
|
64bf4bf58a | ||
|
|
b917b65d49 | ||
|
|
6d502045cb | ||
|
|
de86c85b23 | ||
|
|
1809cbe1f4 | ||
|
|
61bd70a6c9 | ||
|
|
0795a7b4f8 | ||
|
|
02e37e5096 | ||
|
|
273876873e | ||
|
|
9417cd85ab | ||
|
|
6f3e40e390 | ||
|
|
89cea9289b | ||
|
|
038abb9be7 | ||
|
|
709fbc500e | ||
|
|
57b982eea3 | ||
|
|
ef3cef530b | ||
|
|
6cc9516744 | ||
|
|
bfd19280d5 | ||
|
|
2eff3912fb | ||
|
|
69c5c9458b | ||
|
|
af44b99b6f | ||
|
|
c969bc32f5 | ||
|
|
a987ffb745 | ||
|
|
babd4eb5f8 | ||
|
|
d922d4d054 | ||
|
|
940013f9a6 | ||
|
|
e15c8c8c89 | ||
|
|
dcab4b0d4a | ||
|
|
8558ca1020 | ||
|
|
17d036302f | ||
|
|
49707b0a17 | ||
|
|
e52dfababc | ||
|
|
55f13ae654 | ||
|
|
8b61fbebe4 | ||
|
|
a48b217cf3 | ||
|
|
051cd7da2b | ||
|
|
e5b60c2f95 | ||
|
|
cdd710d672 | ||
|
|
1ee6a37188 | ||
|
|
9d868c5090 | ||
|
|
25ec0d93a9 | ||
|
|
2330a30021 | ||
|
|
f37b943f7e | ||
|
|
2c95087056 | ||
|
|
d854773e84 | ||
|
|
b6ce1cbae9 | ||
|
|
875562c385 | ||
|
|
85c102cfc1 | ||
|
|
484a31d082 | ||
|
|
c0b214612a | ||
|
|
3e208bad9b | ||
|
|
7041b79d2a | ||
|
|
2b25573378 | ||
|
|
8efa0295df | ||
|
|
46352ff9c2 | ||
|
|
c276a31b9c | ||
|
|
c904334c83 | ||
|
|
3b0d05e9c9 | ||
|
|
37d606eda6 | ||
|
|
bfd20b7a6b | ||
|
|
a0960555fc | ||
|
|
76ec9683cb | ||
|
|
4ee77656be | ||
|
|
6e58c01e8d | ||
|
|
81e9ddca9b | ||
|
|
a8cc50a495 | ||
|
|
9a231821ea | ||
|
|
d8cd96100e | ||
|
|
31b5031808 | ||
|
|
5906ed7e39 | ||
|
|
c6bb855456 | ||
|
|
2260cb1668 | ||
|
|
65f2d6202f | ||
|
|
733025bd53 | ||
|
|
c6db98c1c2 | ||
|
|
fbb26bbc63 | ||
|
|
5ea77f3e75 | ||
|
|
83694fe537 | ||
|
|
045e074d0f | ||
|
|
6507529db3 | ||
|
|
fa44dd1a26 | ||
|
|
0227769c34 | ||
|
|
375e20a67b | ||
|
|
2724838cf8 | ||
|
|
91dec9e83d | ||
|
|
0483eeca82 | ||
|
|
77a6fd6aff | ||
|
|
8a24427bf7 | ||
|
|
dc011a489c | ||
|
|
9d3d8d5945 | ||
|
|
b255f557db | ||
|
|
05ebd15053 | ||
|
|
6245ad9298 | ||
|
|
3ac2399115 | ||
|
|
182bf177a3 | ||
|
|
d3aa69757b | ||
|
|
56bce1431b | ||
|
|
fa0e836f65 | ||
|
|
e2cc1cb28a | ||
|
|
1d0fd04393 | ||
|
|
ddb6adae2b | ||
|
|
4cabc8a87a | ||
|
|
04c27cc7d5 | ||
|
|
057fb44061 | ||
|
|
51c72aa5be | ||
|
|
769256f0a5 | ||
|
|
7bbdbd6997 | ||
|
|
95a0669e5c | ||
|
|
71f33ed44e | ||
|
|
18f89faf03 | ||
|
|
9bd7aade43 | ||
|
|
754b16c8c8 | ||
|
|
e473c710a2 | ||
|
|
39da5bbe01 | ||
|
|
027709d1e8 | ||
|
|
959b43743c | ||
|
|
43a2b1fa3c | ||
|
|
8fc0316253 |
13
.flake8
Normal file
13
.flake8
Normal file
@@ -0,0 +1,13 @@
|
||||
[flake8]
|
||||
max-line-length=120
|
||||
extend-ignore =
|
||||
D105
|
||||
D107
|
||||
D100
|
||||
D104
|
||||
E203
|
||||
E712
|
||||
D401
|
||||
F811
|
||||
exclude =
|
||||
docs/source/*
|
||||
41
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
41
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: "[BUG] - <bug title goes here>"
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
### Describe the bug:
|
||||
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
### To Reproduce:
|
||||
|
||||
Steps to reproduce the behaviour:
|
||||
|
||||
1. Import '...'
|
||||
2. Instantiate '....'
|
||||
3. Pass to '....'
|
||||
4. Run '....'
|
||||
5. See error
|
||||
|
||||
### Expected behaviour
|
||||
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
### Screenshots/Outputs
|
||||
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
### Environment (please complete the following information)
|
||||
|
||||
- **OS:** [e.g. Ubuntu 22.04]
|
||||
- **Python:** [e.g. 3.10.11]
|
||||
- **PrimAITE Version:** [e.g. v2.0.0]
|
||||
- **Software:** [e.g. cli, Jupyter, PyCharm, VSCode etc.]
|
||||
|
||||
### Additional context
|
||||
|
||||
Add any other context about the problem here.
|
||||
24
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
24
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: "[REQUEST] - <request title goes here>"
|
||||
labels: feature_request
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
### Is your feature request related to a problem?
|
||||
|
||||
If so, please give a concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
### Describe the solution you'd like:
|
||||
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
### Describe alternatives you've considered:
|
||||
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
### Additional context:
|
||||
|
||||
Add any other context or screenshots about the feature request here.
|
||||
60
.github/workflows/build-sphinx.yml
vendored
Normal file
60
.github/workflows/build-sphinx.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: build-sphinx-to-github-pages
|
||||
|
||||
env:
|
||||
GITHUB_ACTOR: Autonomous-Resilient-Cyber-Defence
|
||||
GITHUB_REPOSITORY: Autonomous-Resilient-Cyber-Defence/PrimAITE
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN}}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install python dev
|
||||
run: |
|
||||
set -x
|
||||
sudo apt-get update
|
||||
sudo add-apt-repository ppa:deadsnakes/ppa -y
|
||||
sudo apt install python${{ matrix.python-version}}-dev -y
|
||||
|
||||
- name: Install Git
|
||||
run: |
|
||||
set -x
|
||||
sudo apt-get install -y git
|
||||
shell: bash
|
||||
|
||||
- name: Set pip, wheel, setuptools versions
|
||||
run: |
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
pip install setuptools==66 --upgrade
|
||||
pip install build
|
||||
|
||||
- name: Install PrimAITE for docs autosummary
|
||||
run: |
|
||||
set -x
|
||||
python -m pip install -e .[dev]
|
||||
|
||||
- name: Run build script for Sphinx pages
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
run: |
|
||||
set -x
|
||||
bash $PWD/docs/build-sphinx-docs-to-github-pages.sh
|
||||
66
.github/workflows/python-package.yml
vendored
Normal file
66
.github/workflows/python-package.yml
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
name: Python package
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
- dev-gui
|
||||
- 'release/**'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
- dev-gui
|
||||
- 'release/**'
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install python dev
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo add-apt-repository ppa:deadsnakes/ppa -y
|
||||
sudo apt install python${{ matrix.python-version}}-dev -y
|
||||
|
||||
- name: Install Build Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
pip install setuptools==66 --upgrade
|
||||
pip install build
|
||||
|
||||
- name: Build PrimAITE
|
||||
run: |
|
||||
python -m build
|
||||
|
||||
- name: Install PrimAITE
|
||||
run: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
|
||||
- name: Perform PrimAITE Setup
|
||||
run: |
|
||||
primaite setup
|
||||
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings.
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest tests/
|
||||
152
.gitignore
vendored
Normal file
152
.gitignore
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
tests/assets/**/*.png
|
||||
tests/assets/**/tensorboard_logs/
|
||||
tests/assets/**/checkpoints/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/source/_autosummary
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
docs/source/primaite-dependencies.rst
|
||||
|
||||
# outputs
|
||||
src/primaite/outputs/
|
||||
|
||||
# benchmark session outputs
|
||||
benchmark/output
|
||||
29
.pre-commit-config.yaml
Normal file
29
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
repos:
|
||||
- repo: http://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1000']
|
||||
- id: mixed-line-ending
|
||||
- id: requirements-txt-fixer
|
||||
- repo: http://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [ "--line-length=120" ]
|
||||
additional_dependencies:
|
||||
- jupyter
|
||||
- repo: http://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args: [ "--profile", "black" ]
|
||||
- repo: http://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
additional_dependencies:
|
||||
- flake8-docstrings
|
||||
90
CHANGELOG.md
Normal file
90
CHANGELOG.md
Normal file
@@ -0,0 +1,90 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [2.0.0] - 2023-07-26
|
||||
|
||||
### Added
|
||||
- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE.
|
||||
- Application Directories to enable PrimAITE as a Python package with predefined directories for storage.
|
||||
- Support for Ray Rllib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib.
|
||||
- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`.
|
||||
- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options.
|
||||
- Session loading to revisit previously run sessions for SB3 Agents.
|
||||
- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface.
|
||||
- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs:
|
||||
1. Session Metadata
|
||||
2. Results
|
||||
3. Diagrams
|
||||
4. Saved agents (training checkpoints and a final trained agent).
|
||||
- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup.
|
||||
- Benchmarking of PrimAITE performance, showcasing session and step durations for reference.
|
||||
- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`.
|
||||
|
||||
### Changed
|
||||
- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions.
|
||||
- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`.
|
||||
- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names.
|
||||
- Docs and Tests now sit outside the `src/` directory.
|
||||
- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages.
|
||||
- All dependencies are now defined in the `pyproject.toml` file.
|
||||
- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each.
|
||||
- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management.
|
||||
- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations.
|
||||
- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority.
|
||||
|
||||
|
||||
### Fixed
|
||||
- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation.
|
||||
- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes.
|
||||
- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands.
|
||||
- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making.
|
||||
|
||||
## [1.1.1] - 2023-06-27
|
||||
|
||||
### Bug Fixes
|
||||
* Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by:
|
||||
* Implementing a reference copy of all green IERs (`self.green_iers_reference`).
|
||||
* Clearing the traffic on reference IERs at the same time as the live IERs.
|
||||
* Passing the `green_iers_reference` to the `apply_iers` function at the reference stage.
|
||||
* Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`.
|
||||
* Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment.
|
||||
* Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes.
|
||||
* Removing the unnecessary "Reapply PoL and IERs" action from the step function.
|
||||
* Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function.
|
||||
|
||||
## [1.1.0] - 2023-03-13
|
||||
|
||||
### Added
|
||||
* The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function.
|
||||
* The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name.
|
||||
* Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components:
|
||||
* Blue agent observation space;
|
||||
* Blue agent action space;
|
||||
* Reward function;
|
||||
* Node pattern-of-life.
|
||||
* The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network).
|
||||
* New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile.
|
||||
* NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability.
|
||||
* The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value.
|
||||
* The User Guide has been updated to reflect all the above changes.
|
||||
|
||||
### Changed
|
||||
* "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing.
|
||||
* "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
|
||||
* "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
|
||||
* "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file.
|
||||
* Updates to Transactions.
|
||||
|
||||
### Fixed
|
||||
* Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default.
|
||||
|
||||
|
||||
|
||||
[unreleased]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/compare/v2.0.0...HEAD
|
||||
[2.0.0]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/tag/v2.0.0
|
||||
39
CONTRIBUTING.md
Normal file
39
CONTRIBUTING.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# How to contribute to PrimAITE?
|
||||
|
||||
|
||||
### **Did you find a bug?**
|
||||
|
||||
|
||||
* **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues).
|
||||
* If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D+-+%3Cbug+title+goes+here%3E). Be sure to follow our bug report template with the headers **Describe the bug**, **To Reproduce**, **Expected behaviour**, **Screenshots/Outputs**, **Environment**, and **Additional context**
|
||||
|
||||
|
||||
### **Do you have a solution to fix the bug?**
|
||||
|
||||
* [Fork the repository](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/fork).
|
||||
* Install the pre-commit hook with `pre-commit install`.
|
||||
* Implement the bug fix.
|
||||
* Update documentation where applicable.
|
||||
* Update the **UNRELEASED** section of the [CHANGELOG.md](CHANGELOG.md) file
|
||||
* Write a suitable test/tests.
|
||||
* Commit the bug fix to the dev branch on your fork. If the bug has an open issue under [Issues](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues), reference the issue in the commit message (e.g. #1 references issue 1).
|
||||
* Submit a pull request from your dev branch to the Autonomous-Resilient-Cyber-Defence/PrimAITE dev branch. Again, if the bug has an open issue under [Issues](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues), reference the issue in the pull request description.
|
||||
|
||||
### **Did you fix whitespace, format code, or make a purely cosmetic patch?**
|
||||
|
||||
Changes that are cosmetic in nature and do not add anything substantial to the stability, functionality, or testability of PrimAITE will generally not be accepted.
|
||||
|
||||
### **Do you intend to add a new feature or change an existing one?**
|
||||
|
||||
* Submit a [feature request issue](https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues/new?assignees=&labels=feature_request&projects=&template=feature_request.md&title=%5BREQUEST%5D+-+%3Crequest+title+goes+here%3E).
|
||||
* Know how to implement the new feature or change? Follow the same steps in the bug fix section above to fork, build, document, test, commit, and submit a pull request.
|
||||
|
||||
### **Do you have questions about the source code?**
|
||||
|
||||
Ask any question about how to use PrimAITE in our discussions section.
|
||||
|
||||
### **Do you want to contribute to the PrimAITE documentation?**
|
||||
|
||||
Please follow the "Do you intend to add a new feature or change an existing one?" section above and tag your feature request issue and pull request with the documentation tag.
|
||||
|
||||
Thank you from the PrimAITE dev team! 🙌
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 - 2025 Defence Science and Technology Laboratory UK (https://dstl.gov.uk)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
2
MANIFEST.in
Normal file
2
MANIFEST.in
Normal file
@@ -0,0 +1,2 @@
|
||||
include src/primaite/setup/_package_data/primaite_config.yaml
|
||||
include src/primaite/config/_package_data/*.yaml
|
||||
BIN
PrimAITE_logo_transparent.png
Normal file
BIN
PrimAITE_logo_transparent.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 274 KiB |
166
README.md
166
README.md
@@ -1,20 +1,146 @@
|
||||
# Introduction
|
||||
TODO: Give a short introduction of your project. Let this section explain the objectives or the motivation behind this project.
|
||||
|
||||
# Getting Started
|
||||
TODO: Guide users through getting your code up and running on their own system. In this section you can talk about:
|
||||
1. Installation process
|
||||
2. Software dependencies
|
||||
3. Latest releases
|
||||
4. API references
|
||||
|
||||
# Build and Test
|
||||
TODO: Describe and show how to build your code and run the tests.
|
||||
|
||||
# Contribute
|
||||
TODO: Explain how other users and developers can contribute to make your code better.
|
||||
|
||||
If you want to learn more about creating good readme files then refer the following [guidelines](https://docs.microsoft.com/en-us/azure/devops/repos/git/create-a-readme?view=azure-devops). You can also seek inspiration from the below readme files:
|
||||
- [ASP.NET Core](https://github.com/aspnet/Home)
|
||||
- [Visual Studio Code](https://github.com/Microsoft/vscode)
|
||||
- [Chakra Core](https://github.com/Microsoft/ChakraCore)
|
||||
# PrimAITE
|
||||
|
||||

|
||||
|
||||
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for the purposes of training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment, which includes:
|
||||
|
||||
- The ability to model a relevant platform / system context;
|
||||
|
||||
- The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, traffic loading, operating systems and services;
|
||||
|
||||
- Operates at machine-speed to enable fast training cycles.
|
||||
|
||||
PrimAITE presents the following features:
|
||||
|
||||
- Highly configurable (via YAML files) to provide the means to model a variety of platform / system laydowns, mission profiles and adversarial attack scenarios;
|
||||
|
||||
- A Reinforcement Learning (RL) reward function based on (a) the ability to counter the specific modelled adversarial cyber-attack, and (b) the ability to ensure mission success;
|
||||
|
||||
- Provision of logging to support AI evaluation and metrics gathering;
|
||||
|
||||
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life, adversarial behaviour and mission data (on a sliding scale of criticality);
|
||||
|
||||
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
|
||||
|
||||
- Application of IERs to the platform / system laydown adheres to the ACL ruleset;
|
||||
|
||||
- Presents an OpenAI gym or RLLib interface to the environment, allowing integration with any compliant defensive agents;
|
||||
|
||||
- Full capture of discrete logs relating to agent training (full system state, agent actions taken, instantaneous and average reward for every step of every episode);
|
||||
|
||||
- NetworkX provides laydown visualisation capability.
|
||||
|
||||
## Getting Started with PrimAITE
|
||||
|
||||
### 💫 Install & Run
|
||||
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
|
||||
Currently, the PrimAITE wheel can only be installed from GitHub. This may change in the future with release to PyPi.
|
||||
|
||||
#### Windows (PowerShell)
|
||||
|
||||
**Prerequisites:**
|
||||
* Manual install of Python >= 3.8 < 3.11
|
||||
|
||||
**Install:**
|
||||
|
||||
``` powershell
|
||||
mkdir ~\primaite
|
||||
cd ~\primaite
|
||||
python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
.\.venv\Scripts\activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
primaite setup
|
||||
```
|
||||
|
||||
**Run:**
|
||||
|
||||
``` bash
|
||||
primaite session
|
||||
```
|
||||
|
||||
#### Unix
|
||||
|
||||
**Prerequisites:**
|
||||
* Manual install of Python >= 3.8 < 3.11
|
||||
|
||||
``` bash
|
||||
sudo add-apt-repository ppa:deadsnakes/ppa
|
||||
sudo apt install python3.10
|
||||
sudo apt-get install python3-pip
|
||||
sudo apt-get install python3-venv
|
||||
```
|
||||
**Install:**
|
||||
|
||||
``` bash
|
||||
mkdir ~/primaite
|
||||
cd ~/primaite
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
primaite setup
|
||||
```
|
||||
|
||||
**Run:**
|
||||
|
||||
``` bash
|
||||
primaite session
|
||||
```
|
||||
|
||||
|
||||
### Developer Install from Source
|
||||
To make your own changes to PrimAITE, perform the install from source (developer install)
|
||||
|
||||
#### 1. Clone the PrimAITE repository
|
||||
``` unix
|
||||
git clone git@github.com:Autonomous-Resilient-Cyber-Defence/PrimAITE.git
|
||||
```
|
||||
|
||||
#### 2. CD into the repo directory
|
||||
``` unix
|
||||
cd PrimAITE
|
||||
```
|
||||
#### 3. Create a new python virtual environment (venv)
|
||||
|
||||
```unix
|
||||
python3 -m venv venv
|
||||
```
|
||||
|
||||
#### 4. Activate the venv
|
||||
|
||||
##### Unix
|
||||
```bash
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
##### Windows (Powershell)
|
||||
```powershell
|
||||
.\venv\Scripts\activate
|
||||
```
|
||||
|
||||
#### 5. Install `primaite` with the dev extra into the venv along with all of it's dependencies
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
```
|
||||
|
||||
#### 6. Perform the PrimAITE setup:
|
||||
|
||||
```bash
|
||||
primaite setup
|
||||
```
|
||||
|
||||
## 📚 Building documentation
|
||||
The PrimAITE documentation can be built with the following commands:
|
||||
|
||||
##### Unix
|
||||
```bash
|
||||
cd docs
|
||||
make html
|
||||
```
|
||||
|
||||
##### Windows (Powershell)
|
||||
```powershell
|
||||
cd docs
|
||||
.\make.bat html
|
||||
```
|
||||
|
||||
164
benchmark/config/benchmark_training_config.yaml
Normal file
164
benchmark/config/benchmark_training_config.yaml
Normal file
@@ -0,0 +1,164 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
|
||||
# observation space
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 500
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 1
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 0
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
449
benchmark/primaite_benchmark.py
Normal file
449
benchmark/primaite_benchmark.py
Normal file
@@ -0,0 +1,449 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import GPUtil
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import psutil
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
from pylatex import Command, Document
|
||||
from pylatex import Figure as LatexFigure
|
||||
from pylatex import Section, Subsection, Tabular
|
||||
from pylatex.utils import bold
|
||||
|
||||
import primaite
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.data_viz.session_plots import get_plotly_config
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
_LOGGER = primaite.getLogger(__name__)
|
||||
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
|
||||
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
|
||||
# Clear and recreate the output directory
|
||||
if _OUTPUT_ROOT.exists():
|
||||
shutil.rmtree(_OUTPUT_ROOT)
|
||||
_OUTPUT_ROOT.mkdir()
|
||||
|
||||
_TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.yaml"
|
||||
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
|
||||
|
||||
|
||||
def get_size(size_bytes: int):
|
||||
"""
|
||||
Scale bytes to its proper format.
|
||||
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
|
||||
:
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if size_bytes < factor:
|
||||
return f"{size_bytes:.2f}{unit}B"
|
||||
size_bytes /= factor
|
||||
|
||||
|
||||
def _get_system_info() -> Dict:
|
||||
"""Builds and returns a dict containing system info."""
|
||||
uname = platform.uname()
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
virtual_mem = psutil.virtual_memory()
|
||||
swap_mem = psutil.swap_memory()
|
||||
gpus = GPUtil.getGPUs()
|
||||
return {
|
||||
"System": {
|
||||
"OS": uname.system,
|
||||
"OS Version": uname.version,
|
||||
"Machine": uname.machine,
|
||||
"Processor": uname.processor,
|
||||
},
|
||||
"CPU": {
|
||||
"Physical Cores": psutil.cpu_count(logical=False),
|
||||
"Total Cores": psutil.cpu_count(logical=True),
|
||||
"Max Frequency": f"{cpu_freq.max:.2f}Mhz",
|
||||
},
|
||||
"Memory": {"Total": get_size(virtual_mem.total), "Swap Total": get_size(swap_mem.total)},
|
||||
"GPU": [{"Name": gpu.name, "Total Memory": f"{gpu.memoryTotal}MB"} for gpu in gpus],
|
||||
}
|
||||
|
||||
|
||||
def _build_benchmark_latex_report(
|
||||
benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
|
||||
):
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
|
||||
# Create a new document
|
||||
doc = Document("report", geometry_options=geometry_options)
|
||||
# Title
|
||||
doc.preamble.append(Command("title", f"PrimAITE {primaite_version} Learning Benchmark"))
|
||||
doc.preamble.append(Command("author", "PrimAITE Dev Team"))
|
||||
doc.preamble.append(Command("date", datetime.now().date()))
|
||||
doc.append(Command("maketitle"))
|
||||
|
||||
sessions = data["total_sessions"]
|
||||
episodes = data["training_config"]["num_train_episodes"]
|
||||
steps = data["training_config"]["num_train_steps"]
|
||||
|
||||
# Body
|
||||
with doc.create(Section("Introduction")):
|
||||
doc.append(
|
||||
f"PrimAITE v{primaite_version} was benchmarked automatically upon release. Learning rate metrics "
|
||||
f"were captured to be referenced during system-level testing and user acceptance testing (UAT)."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe benchmarking process consists of running {sessions} training session using the same "
|
||||
f"training and lay down config files. Each session trains an agent for {episodes} episodes, "
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
|
||||
f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the combined average reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
with doc.create(Section("System Information")):
|
||||
with doc.create(Subsection("Python")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
table.add_row((bold("Version"), sys.version))
|
||||
table.add_hline()
|
||||
for section, section_data in data["system_info"].items():
|
||||
if section_data:
|
||||
with doc.create(Subsection(section)):
|
||||
if isinstance(section_data, dict):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for key, value in section_data.items():
|
||||
table.add_row((bold(key), value))
|
||||
table.add_hline()
|
||||
elif isinstance(section_data, list):
|
||||
headers = section_data[0].keys()
|
||||
tabs_str = "|".join(["l" for _ in range(len(headers))])
|
||||
tabs_str = f"|{tabs_str}|"
|
||||
with doc.create(Tabular(tabs_str)) as table:
|
||||
table.add_hline()
|
||||
table.add_row([bold(h) for h in headers])
|
||||
table.add_hline()
|
||||
for item in section_data:
|
||||
table.add_row(item.values())
|
||||
table.add_hline()
|
||||
|
||||
headers_map = {
|
||||
"total_sessions": "Total Sessions",
|
||||
"total_episodes": "Total Episodes",
|
||||
"total_time_steps": "Total Steps",
|
||||
"av_s_per_session": "Av Session Duration (s)",
|
||||
"av_s_per_step": "Av Step Duration (s)",
|
||||
"av_s_per_100_steps_10_nodes": "Av Duration per 100 Steps per 10 Nodes (s)",
|
||||
}
|
||||
with doc.create(Section("Stats")):
|
||||
with doc.create(Subsection("Benchmark Results")):
|
||||
with doc.create(Tabular("|l|l|")) as table:
|
||||
table.add_hline()
|
||||
for section, header in headers_map.items():
|
||||
if section.startswith("av_"):
|
||||
table.add_row((bold(header), f"{data[section]:.4f}"))
|
||||
else:
|
||||
table.add_row((bold(header), data[section]))
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
||||
|
||||
class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
"""A benchmarking primaite session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
|
||||
@property
|
||||
def env(self) -> Primaite:
|
||||
"""Direct access to the env for ease of testing."""
|
||||
return self._agent_session._env # noqa
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
shutil.rmtree(self.session_path)
|
||||
_LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")
|
||||
|
||||
def _learn_benchmark_durations(self) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Calculate and return the learning benchmark durations.
|
||||
|
||||
Calculates the:
|
||||
- Total learning time in seconds
|
||||
- Total learning time per time step in seconds
|
||||
- Total learning time per 100 time steps per 10 nodes in seconds
|
||||
|
||||
:return: The learning benchmark durations as a Tuple of three floats:
|
||||
Tuple[total_s, s_per_step, s_per_100_steps_10_nodes].
|
||||
"""
|
||||
data = self.metadata_file_as_dict()
|
||||
start_dt = datetime.fromisoformat(data["start_datetime"])
|
||||
end_dt = datetime.fromisoformat(data["end_datetime"])
|
||||
delta = end_dt - start_dt
|
||||
total_s = delta.total_seconds()
|
||||
|
||||
total_steps = data["learning"]["total_time_steps"]
|
||||
s_per_step = total_s / total_steps
|
||||
|
||||
num_nodes = self.env.num_nodes
|
||||
num_intervals = total_steps / 100
|
||||
av_interval_time = total_s / num_intervals
|
||||
s_per_100_steps_10_nodes = av_interval_time / (num_nodes / 10)
|
||||
|
||||
return total_s, s_per_step, s_per_100_steps_10_nodes
|
||||
|
||||
def learn_metadata_dict(self) -> Dict[str, Any]:
|
||||
"""Metadata specific to the learning session."""
|
||||
total_s, s_per_step, s_per_100_steps_10_nodes = self._learn_benchmark_durations()
|
||||
return {
|
||||
"total_episodes": self.env.actual_episode_count,
|
||||
"total_time_steps": self.env.total_step_count,
|
||||
"total_s": total_s,
|
||||
"s_per_step": s_per_step,
|
||||
"s_per_100_steps_10_nodes": s_per_100_steps_10_nodes,
|
||||
"av_reward_per_episode": self.learn_av_reward_per_episode_dict(),
|
||||
}
|
||||
|
||||
|
||||
def _get_benchmark_session_path(session_timestamp: datetime) -> Path:
|
||||
return _OUTPUT_ROOT / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
def _get_benchmark_primaite_session() -> BenchmarkPrimaiteSession:
|
||||
with patch("primaite.agents.agent_abc.get_session_path", _get_benchmark_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
return BenchmarkPrimaiteSession(_TRAINING_CONFIG_PATH, _LAY_DOWN_CONFIG_PATH)
|
||||
|
||||
|
||||
def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) -> dict:
|
||||
n = len(metadata_dict)
|
||||
with open(_TRAINING_CONFIG_PATH, "r") as file:
|
||||
training_config_dict = yaml.safe_load(file)
|
||||
with open(_LAY_DOWN_CONFIG_PATH, "r") as file:
|
||||
lay_down_config_dict = yaml.safe_load(file)
|
||||
averaged_data = {
|
||||
"start_timestamp": start_datetime.isoformat(),
|
||||
"end_datetime": datetime.now().isoformat(),
|
||||
"primaite_version": primaite.__version__,
|
||||
"system_info": _get_system_info(),
|
||||
"total_sessions": n,
|
||||
"total_episodes": sum(d["total_episodes"] for d in metadata_dict.values()),
|
||||
"total_time_steps": sum(d["total_time_steps"] for d in metadata_dict.values()),
|
||||
"av_s_per_session": sum(d["total_s"] for d in metadata_dict.values()) / n,
|
||||
"av_s_per_step": sum(d["s_per_step"] for d in metadata_dict.values()) / n,
|
||||
"av_s_per_100_steps_10_nodes": sum(d["s_per_100_steps_10_nodes"] for d in metadata_dict.values()) / n,
|
||||
"combined_av_reward_per_episode": {},
|
||||
"session_av_reward_per_episode": {k: v["av_reward_per_episode"] for k, v in metadata_dict.items()},
|
||||
"training_config": training_config_dict,
|
||||
"lay_down_config": lay_down_config_dict,
|
||||
}
|
||||
|
||||
episodes = metadata_dict[1]["av_reward_per_episode"].keys()
|
||||
|
||||
for episode in episodes:
|
||||
combined_av_reward = sum(metadata_dict[k]["av_reward_per_episode"][episode] for k in metadata_dict.keys()) / n
|
||||
averaged_data["combined_av_reward_per_episode"][episode] = combined_av_reward
|
||||
|
||||
return averaged_data
|
||||
|
||||
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict):
|
||||
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
|
||||
|
||||
return (
|
||||
pl.from_dict(data)
|
||||
.with_columns(rolling_mean=pl.col("av_reward").rolling_mean(window_size=25))
|
||||
.rename({"rolling_mean": "rolling_av_reward"})
|
||||
)
|
||||
|
||||
|
||||
def _plot_benchmark_metadata(
|
||||
benchmark_metadata_dict: Dict,
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
|
||||
for session, av_reward_dict in benchmark_metadata_dict["session_av_reward_per_episode"].items():
|
||||
df = _get_df_from_episode_av_reward_dict(av_reward_dict)
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["av_reward"],
|
||||
mode="lines",
|
||||
name=f"Session {session}",
|
||||
opacity=0.25,
|
||||
line={"color": "#a6a6a6"},
|
||||
)
|
||||
)
|
||||
|
||||
df = _get_df_from_episode_av_reward_dict(benchmark_metadata_dict["combined_av_reward_per_episode"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"], y=df["av_reward"], mode="lines", name="Combined Session Av", line={"color": "#FF0000"}
|
||||
)
|
||||
)
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["episode"],
|
||||
y=df["rolling_av_reward"],
|
||||
mode="lines",
|
||||
name="Rolling Av (Combined Session Av)",
|
||||
line={"color": "#4CBB17"},
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_all_benchmarks_combined_session_av():
|
||||
"""
|
||||
Plot the Benchmark results for each released version of PrimAITE.
|
||||
|
||||
Does this by iterating over the ``benchmark/results`` directory and
|
||||
extracting the benchmark metadata json for each version that has been
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
"""
|
||||
title = "PrimAITE Versions Learning Benchmark"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
|
||||
for dir in _RESULTS_ROOT.iterdir():
|
||||
if dir.is_dir():
|
||||
metadata_file = dir / f"{dir.name}_benchmark_metadata.json"
|
||||
with open(metadata_file, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
df = _get_df_from_episode_av_reward_dict(metadata_dict["combined_av_reward_per_episode"])
|
||||
|
||||
fig.add_trace(go.Scatter(x=df["episode"], y=df["rolling_av_reward"], mode="lines", name=dir.name))
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def run():
|
||||
"""Run the PrimAITE benchmark."""
|
||||
start_datetime = datetime.now()
|
||||
av_reward_per_episode_dicts = {}
|
||||
for i in range(1, 11):
|
||||
print(f"Starting Benchmark Session: {i}")
|
||||
with _get_benchmark_primaite_session() as session:
|
||||
session.learn()
|
||||
av_reward_per_episode_dicts[i] = session.learn_metadata_dict()
|
||||
|
||||
benchmark_metadata = _build_benchmark_results_dict(
|
||||
start_datetime=start_datetime, metadata_dict=av_reward_per_episode_dicts
|
||||
)
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = _RESULTS_ROOT / v_str
|
||||
if version_result_dir.exists():
|
||||
shutil.rmtree(version_result_dir)
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
fig = _plot_benchmark_metadata(benchmark_metadata, title=title)
|
||||
this_version_plot_path = version_result_dir / f"{title}.png"
|
||||
fig.write_image(this_version_plot_path)
|
||||
|
||||
fig = _plot_all_benchmarks_combined_session_av()
|
||||
|
||||
all_version_plot_path = _RESULTS_ROOT / "PrimAITE Versions Learning Benchmark.png"
|
||||
fig.write_image(all_version_plot_path)
|
||||
|
||||
_build_benchmark_latex_report(benchmark_metadata, this_version_plot_path, all_version_plot_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
BIN
benchmark/results/PrimAITE Versions Learning Benchmark.png
Normal file
BIN
benchmark/results/PrimAITE Versions Learning Benchmark.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 79 KiB |
BIN
benchmark/results/v2.0.0/PrimAITE v2.0.0 Learning Benchmark.pdf
Normal file
BIN
benchmark/results/v2.0.0/PrimAITE v2.0.0 Learning Benchmark.pdf
Normal file
Binary file not shown.
BIN
benchmark/results/v2.0.0/PrimAITE v2.0.0 Learning Benchmark.png
Normal file
BIN
benchmark/results/v2.0.0/PrimAITE v2.0.0 Learning Benchmark.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 225 KiB |
6351
benchmark/results/v2.0.0/v2.0.0_benchmark_metadata.json
Normal file
6351
benchmark/results/v2.0.0/v2.0.0_benchmark_metadata.json
Normal file
File diff suppressed because it is too large
Load Diff
521
diagram/classes.puml
Normal file
521
diagram/classes.puml
Normal file
@@ -0,0 +1,521 @@
|
||||
@startuml classes
|
||||
set namespaceSeparator none
|
||||
class "ACLRule" as primaite.acl.acl_rule.ACLRule {
|
||||
dest_ip : str
|
||||
permission
|
||||
port : str
|
||||
protocol : str
|
||||
source_ip : str
|
||||
get_dest_ip() -> str
|
||||
get_permission() -> str
|
||||
get_port() -> str
|
||||
get_protocol() -> str
|
||||
get_source_ip() -> str
|
||||
}
|
||||
class "AbstractObservationComponent" as primaite.environment.observations.AbstractObservationComponent {
|
||||
current_observation : NotImplementedType, ndarray
|
||||
env : str
|
||||
space : Space
|
||||
structure : List[str]
|
||||
{abstract}generate_structure() -> List[str]
|
||||
{abstract}update() -> None
|
||||
}
|
||||
class "AccessControlList" as primaite.acl.access_control_list.AccessControlList {
|
||||
acl
|
||||
acl_implicit_permission
|
||||
acl_implicit_rule
|
||||
max_acl_rules : int
|
||||
add_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str, _position: str) -> None
|
||||
check_address_match(_rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool
|
||||
get_dictionary_hash(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> int
|
||||
get_relevant_rules(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> Dict[int, ACLRule]
|
||||
is_blocked(_source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool
|
||||
remove_all_rules() -> None
|
||||
remove_rule(_permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str) -> None
|
||||
}
|
||||
class "AccessControlList_" as primaite.environment.observations.AccessControlList_ {
|
||||
current_observation : ndarray
|
||||
space : MultiDiscrete
|
||||
structure : list
|
||||
generate_structure() -> List[str]
|
||||
update() -> None
|
||||
}
|
||||
|
||||
class "ActiveNode" as primaite.nodes.active_node.ActiveNode {
|
||||
file_system_action_count : int
|
||||
file_system_scanning : bool
|
||||
file_system_scanning_count : int
|
||||
file_system_state_actual : GOOD
|
||||
file_system_state_observed : REPAIRING, RESTORING, GOOD
|
||||
ip_address : str
|
||||
patching_count : int
|
||||
software_state
|
||||
software_state : GOOD
|
||||
set_file_system_state(file_system_state: FileSystemState) -> None
|
||||
set_file_system_state_if_not_compromised(file_system_state: FileSystemState) -> None
|
||||
set_software_state_if_not_compromised(software_state: SoftwareState) -> None
|
||||
start_file_system_scan() -> None
|
||||
update_booting_status() -> None
|
||||
update_file_system_state() -> None
|
||||
update_os_patching_status() -> None
|
||||
update_resetting_status() -> None
|
||||
}
|
||||
class "AgentSessionABC" as primaite.agents.agent_abc.AgentSessionABC {
|
||||
checkpoints_path
|
||||
evaluation_path
|
||||
is_eval : bool
|
||||
learning_path
|
||||
sb3_output_verbose_level : NONE
|
||||
session_path : Union[str, Path]
|
||||
session_timestamp : datetime
|
||||
timestamp_str
|
||||
uuid
|
||||
close() -> None
|
||||
{abstract}evaluate() -> None
|
||||
{abstract}export() -> None
|
||||
{abstract}learn() -> None
|
||||
load(path: Union[str, Path]) -> None
|
||||
{abstract}save() -> None
|
||||
}
|
||||
|
||||
class "DoNothingACLAgent" as primaite.agents.simple.DoNothingACLAgent {
|
||||
}
|
||||
class "DoNothingNodeAgent" as primaite.agents.simple.DoNothingNodeAgent {
|
||||
}
|
||||
class "DummyAgent" as primaite.agents.simple.DummyAgent {
|
||||
}
|
||||
class "HardCodedACLAgent" as primaite.agents.hardcoded_acl.HardCodedACLAgent {
|
||||
get_allow_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule]
|
||||
get_allow_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
|
||||
get_blocked_green_iers(green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[str, IER]
|
||||
get_blocking_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
|
||||
get_deny_acl_rules(source_node_id: int, dest_node_id: str, protocol: int, port: str, acl: AccessControlList, nodes: Dict[str, NodeUnion], services_list: List[str]) -> Dict[int, ACLRule]
|
||||
get_matching_acl_rules(source_node_id: str, dest_node_id: str, protocol: str, port: str, acl: AccessControlList, nodes: Dict[str, Union[ServiceNode, ActiveNode]], services_list: List[str]) -> Dict[int, ACLRule]
|
||||
get_matching_acl_rules_for_ier(ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]) -> Dict[int, ACLRule]
|
||||
}
|
||||
class "HardCodedAgentSessionABC" as primaite.agents.hardcoded_abc.HardCodedAgentSessionABC {
|
||||
is_eval : bool
|
||||
evaluate() -> None
|
||||
export() -> None
|
||||
learn() -> None
|
||||
load(path: Union[str, Path]) -> None
|
||||
save() -> None
|
||||
}
|
||||
class "HardCodedNodeAgent" as primaite.agents.hardcoded_node.HardCodedNodeAgent {
|
||||
}
|
||||
class "IER" as primaite.pol.ier.IER {
|
||||
dest_node_id : str
|
||||
end_step : int
|
||||
id : str
|
||||
load : int
|
||||
mission_criticality : int
|
||||
port : str
|
||||
protocol : str
|
||||
running : bool
|
||||
source_node_id : str
|
||||
start_step : int
|
||||
get_dest_node_id() -> str
|
||||
get_end_step() -> int
|
||||
get_id() -> str
|
||||
get_is_running() -> bool
|
||||
get_load() -> int
|
||||
get_mission_criticality() -> int
|
||||
get_port() -> str
|
||||
get_protocol() -> str
|
||||
get_source_node_id() -> str
|
||||
get_start_step() -> int
|
||||
set_is_running(_value: bool) -> None
|
||||
}
|
||||
class "Link" as primaite.links.link.Link {
|
||||
bandwidth : int
|
||||
dest_node_name : str
|
||||
id : str
|
||||
protocol_list : List[Protocol]
|
||||
source_node_name : str
|
||||
add_protocol(_protocol: str) -> None
|
||||
add_protocol_load(_protocol: str, _load: int) -> None
|
||||
clear_traffic() -> None
|
||||
get_bandwidth() -> int
|
||||
get_current_load() -> int
|
||||
get_dest_node_name() -> str
|
||||
get_id() -> str
|
||||
get_protocol_list() -> List[Protocol]
|
||||
get_source_node_name() -> str
|
||||
}
|
||||
class "LinkTrafficLevels" as primaite.environment.observations.LinkTrafficLevels {
|
||||
current_observation : ndarray
|
||||
space : MultiDiscrete
|
||||
structure : list
|
||||
generate_structure() -> List[str]
|
||||
update() -> None
|
||||
}
|
||||
class "Node" as primaite.nodes.node.Node {
|
||||
booting_count : int
|
||||
config_values
|
||||
hardware_state : BOOTING, ON, RESETTING, OFF
|
||||
name : Final[str]
|
||||
node_id : Final[str]
|
||||
node_type : Final[NodeType]
|
||||
priority
|
||||
resetting_count : int
|
||||
shutting_down_count : int
|
||||
reset() -> None
|
||||
turn_off() -> None
|
||||
turn_on() -> None
|
||||
update_booting_status() -> None
|
||||
update_resetting_status() -> None
|
||||
update_shutdown_status() -> None
|
||||
}
|
||||
class "NodeLinkTable" as primaite.environment.observations.NodeLinkTable {
|
||||
current_observation : ndarray
|
||||
space : Box
|
||||
structure : list
|
||||
generate_structure() -> List[str]
|
||||
update() -> None
|
||||
}
|
||||
class "NodeStateInstructionGreen" as primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen {
|
||||
end_step : int
|
||||
id : str
|
||||
node_id : str
|
||||
node_pol_type : str
|
||||
service_name : str
|
||||
start_step : int
|
||||
state : Union['HardwareState', 'SoftwareState', 'FileSystemState']
|
||||
get_end_step() -> int
|
||||
get_node_id() -> str
|
||||
get_node_pol_type() -> 'NodePOLType'
|
||||
get_service_name() -> str
|
||||
get_start_step() -> int
|
||||
get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState']
|
||||
}
|
||||
class "NodeStateInstructionRed" as primaite.nodes.node_state_instruction_red.NodeStateInstructionRed {
|
||||
end_step : int
|
||||
id : str
|
||||
initiator : str
|
||||
pol_type
|
||||
service_name : str
|
||||
source_node_id : str
|
||||
source_node_service : str
|
||||
source_node_service_state : str
|
||||
start_step : int
|
||||
state : Union['HardwareState', 'SoftwareState', 'FileSystemState']
|
||||
target_node_id : str
|
||||
get_end_step() -> int
|
||||
get_initiator() -> 'NodePOLInitiator'
|
||||
get_pol_type() -> NodePOLType
|
||||
get_service_name() -> str
|
||||
get_source_node_id() -> str
|
||||
get_source_node_service() -> str
|
||||
get_source_node_service_state() -> str
|
||||
get_start_step() -> int
|
||||
get_state() -> Union['HardwareState', 'SoftwareState', 'FileSystemState']
|
||||
get_target_node_id() -> str
|
||||
}
|
||||
class "NodeStatuses" as primaite.environment.observations.NodeStatuses {
|
||||
current_observation : ndarray
|
||||
space : MultiDiscrete
|
||||
structure : list
|
||||
generate_structure() -> List[str]
|
||||
update() -> None
|
||||
}
|
||||
class "ObservationsHandler" as primaite.environment.observations.ObservationsHandler {
|
||||
current_observation
|
||||
registered_obs_components : List[AbstractObservationComponent]
|
||||
space
|
||||
deregister(obs_component: AbstractObservationComponent) -> None
|
||||
describe_structure() -> List[str]
|
||||
from_config(env: 'Primaite', obs_space_config: dict) -> 'ObservationsHandler'
|
||||
register(obs_component: AbstractObservationComponent) -> None
|
||||
update_obs() -> None
|
||||
update_space() -> None
|
||||
}
|
||||
class "PassiveNode" as primaite.nodes.passive_node.PassiveNode {
|
||||
ip_address
|
||||
}
|
||||
class "Primaite" as primaite.environment.primaite_env.Primaite {
|
||||
ACTION_SPACE_ACL_ACTION_VALUES : int
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES : int
|
||||
ACTION_SPACE_NODE_ACTION_VALUES : int
|
||||
ACTION_SPACE_NODE_PROPERTY_VALUES : int
|
||||
acl
|
||||
action_dict : dict, Dict[int, List[int]]
|
||||
action_space : Discrete, Space
|
||||
action_type : int
|
||||
actual_episode_count
|
||||
agent_identifier
|
||||
average_reward : float
|
||||
env_obs : ndarray, tuple
|
||||
episode_av_reward_writer
|
||||
episode_count : int
|
||||
episode_steps : int
|
||||
green_iers : Dict[str, IER]
|
||||
green_iers_reference : Dict[str, IER]
|
||||
lay_down_config
|
||||
links : Dict[str, Link]
|
||||
links_post_blue : dict
|
||||
links_post_pol : dict
|
||||
links_post_red : dict
|
||||
links_reference : Dict[str, Link]
|
||||
max_number_acl_rules : int
|
||||
network : Graph
|
||||
network_reference : Graph
|
||||
node_pol : Dict[str, NodeStateInstructionGreen]
|
||||
nodes : Dict[str, NodeUnion]
|
||||
nodes_post_blue : dict
|
||||
nodes_post_pol : dict
|
||||
nodes_post_red : dict
|
||||
nodes_reference : Dict[str, NodeUnion]
|
||||
num_links : int
|
||||
num_nodes : int
|
||||
num_ports : int
|
||||
num_services : int
|
||||
obs_config : dict
|
||||
obs_handler
|
||||
observation_space : Tuple, Box, Space
|
||||
observation_type
|
||||
ports_list : List[str]
|
||||
red_iers : Dict[str, IER], dict
|
||||
red_node_pol : dict, Dict[str, NodeStateInstructionRed]
|
||||
services_list : List[str]
|
||||
session_path : Final[Path]
|
||||
step_count : int
|
||||
step_info : Dict[Any]
|
||||
timestamp_str : Final[str]
|
||||
total_reward : float
|
||||
total_step_count : int
|
||||
training_config
|
||||
transaction_writer
|
||||
apply_actions_to_acl(_action: int) -> None
|
||||
apply_actions_to_nodes(_action: int) -> None
|
||||
apply_time_based_updates() -> None
|
||||
close() -> None
|
||||
create_acl_action_dict() -> Dict[int, List[int]]
|
||||
create_acl_rule(item: Dict) -> None
|
||||
create_green_ier(item: Dict) -> None
|
||||
create_green_pol(item: Dict) -> None
|
||||
create_link(item: Dict) -> None
|
||||
create_node(item: Dict) -> None
|
||||
create_node_action_dict() -> Dict[int, List[int]]
|
||||
create_node_and_acl_action_dict() -> Dict[int, List[int]]
|
||||
create_ports_list(ports: Dict) -> None
|
||||
create_red_ier(item: Dict) -> None
|
||||
create_red_pol(item: Dict) -> None
|
||||
create_services_list(services: Dict) -> None
|
||||
get_action_info(action_info: Dict) -> None
|
||||
get_observation_info(observation_info: Dict) -> None
|
||||
init_acl() -> None
|
||||
init_observations() -> Tuple[spaces.Space, np.ndarray]
|
||||
interpret_action_and_apply(_action: int) -> None
|
||||
load_lay_down_config() -> None
|
||||
output_link_status() -> None
|
||||
reset() -> np.ndarray
|
||||
reset_environment() -> None
|
||||
reset_node(item: Dict) -> None
|
||||
save_obs_config(obs_config: dict) -> None
|
||||
set_as_eval() -> None
|
||||
step(action: int) -> Tuple[np.ndarray, float, bool, Dict]
|
||||
update_environent_obs() -> None
|
||||
}
|
||||
class "PrimaiteSession" as primaite.primaite_session.PrimaiteSession {
|
||||
evaluation_path : Optional[Path], Path
|
||||
is_load_session : bool
|
||||
learning_path : Optional[Path], Path
|
||||
session_path : Optional[Path], Path
|
||||
timestamp_str : str, Optional[str]
|
||||
close() -> None
|
||||
evaluate() -> None
|
||||
learn() -> None
|
||||
setup() -> None
|
||||
}
|
||||
class "Protocol" as primaite.common.protocol.Protocol {
|
||||
load : int
|
||||
name : str
|
||||
add_load(_load: int) -> None
|
||||
clear_load() -> None
|
||||
get_load() -> int
|
||||
get_name() -> str
|
||||
}
|
||||
class "RLlibAgent" as primaite.agents.rllib.RLlibAgent {
|
||||
{abstract}evaluate() -> None
|
||||
{abstract}export() -> None
|
||||
learn() -> None
|
||||
{abstract}load(path: Union[str, Path]) -> RLlibAgent
|
||||
save(overwrite_existing: bool) -> None
|
||||
}
|
||||
class "RandomAgent" as primaite.agents.simple.RandomAgent {
|
||||
}
|
||||
class "SB3Agent" as primaite.agents.sb3.SB3Agent {
|
||||
is_eval : bool
|
||||
evaluate() -> None
|
||||
{abstract}export() -> None
|
||||
learn() -> None
|
||||
save() -> None
|
||||
}
|
||||
class "Service" as primaite.common.service.Service {
|
||||
name : str
|
||||
patching_count : int
|
||||
port : str
|
||||
software_state : GOOD
|
||||
reduce_patching_count() -> None
|
||||
}
|
||||
class "ServiceNode" as primaite.nodes.service_node.ServiceNode {
|
||||
services : Dict[str, Service]
|
||||
add_service(service: Service) -> None
|
||||
get_service_state(protocol_name: str) -> SoftwareState
|
||||
has_service(protocol_name: str) -> bool
|
||||
service_is_overwhelmed(protocol_name: str) -> bool
|
||||
service_running(protocol_name: str) -> bool
|
||||
set_service_state(protocol_name: str, software_state: SoftwareState) -> None
|
||||
set_service_state_if_not_compromised(protocol_name: str, software_state: SoftwareState) -> None
|
||||
update_booting_status() -> None
|
||||
update_resetting_status() -> None
|
||||
update_services_patching_status() -> None
|
||||
}
|
||||
class "SessionOutputWriter" as primaite.utils.session_output_writer.SessionOutputWriter {
|
||||
learning_session : bool
|
||||
transaction_writer : bool
|
||||
close() -> None
|
||||
write(data: Union[Tuple, Transaction]) -> None
|
||||
}
|
||||
class "TrainingConfig" as primaite.config.training_config.TrainingConfig {
|
||||
action_type
|
||||
agent_framework
|
||||
agent_identifier
|
||||
agent_load_file : Optional[str]
|
||||
all_ok : float
|
||||
checkpoint_every_n_episodes : int
|
||||
compromised : float
|
||||
compromised_should_be_good : float
|
||||
compromised_should_be_overwhelmed : float
|
||||
compromised_should_be_patching : float
|
||||
corrupt : float
|
||||
corrupt_should_be_destroyed : float
|
||||
corrupt_should_be_good : float
|
||||
corrupt_should_be_repairing : float
|
||||
corrupt_should_be_restoring : float
|
||||
deep_learning_framework
|
||||
destroyed : float
|
||||
destroyed_should_be_corrupt : float
|
||||
destroyed_should_be_good : float
|
||||
destroyed_should_be_repairing : float
|
||||
destroyed_should_be_restoring : float
|
||||
deterministic : bool
|
||||
file_system_repairing_limit : int
|
||||
file_system_restoring_limit : int
|
||||
file_system_scanning_limit : int
|
||||
good_should_be_compromised : float
|
||||
good_should_be_corrupt : float
|
||||
good_should_be_destroyed : float
|
||||
good_should_be_overwhelmed : float
|
||||
good_should_be_patching : float
|
||||
good_should_be_repairing : float
|
||||
good_should_be_restoring : float
|
||||
green_ier_blocked : float
|
||||
hard_coded_agent_view
|
||||
implicit_acl_rule
|
||||
load_agent : bool
|
||||
max_number_acl_rules : int
|
||||
node_booting_duration : int
|
||||
node_reset_duration : int
|
||||
node_shutdown_duration : int
|
||||
num_eval_episodes : int
|
||||
num_eval_steps : int
|
||||
num_train_episodes : int
|
||||
num_train_steps : int
|
||||
observation_space : dict
|
||||
observation_space_high_value : int
|
||||
off_should_be_on : float
|
||||
off_should_be_resetting : float
|
||||
on_should_be_off : float
|
||||
on_should_be_resetting : float
|
||||
os_patching_duration : int
|
||||
overwhelmed : float
|
||||
overwhelmed_should_be_compromised : float
|
||||
overwhelmed_should_be_good : float
|
||||
overwhelmed_should_be_patching : float
|
||||
patching : float
|
||||
patching_should_be_compromised : float
|
||||
patching_should_be_good : float
|
||||
patching_should_be_overwhelmed : float
|
||||
random_red_agent : bool
|
||||
red_ier_running : float
|
||||
repairing : float
|
||||
repairing_should_be_corrupt : float
|
||||
repairing_should_be_destroyed : float
|
||||
repairing_should_be_good : float
|
||||
repairing_should_be_restoring : float
|
||||
resetting : float
|
||||
resetting_should_be_off : float
|
||||
resetting_should_be_on : float
|
||||
restoring : float
|
||||
restoring_should_be_corrupt : float
|
||||
restoring_should_be_destroyed : float
|
||||
restoring_should_be_good : float
|
||||
restoring_should_be_repairing : float
|
||||
sb3_output_verbose_level
|
||||
scanning : float
|
||||
seed : Optional[int]
|
||||
service_patching_duration : int
|
||||
session_type
|
||||
time_delay : int
|
||||
from_dict(config_dict: Dict[str, Any]) -> TrainingConfig
|
||||
to_dict(json_serializable: bool) -> Dict
|
||||
}
|
||||
class "Transaction" as primaite.transactions.transaction.Transaction {
|
||||
action_space : Optional[int]
|
||||
agent_identifier
|
||||
episode_number : int
|
||||
obs_space : str
|
||||
obs_space_description : NoneType, Optional[List[str]], list
|
||||
obs_space_post : Optional[Union['np.ndarray', Tuple['np.ndarray']]]
|
||||
obs_space_pre : Optional[Union['np.ndarray', Tuple['np.ndarray']]]
|
||||
reward : Optional[float], float
|
||||
step_number : int
|
||||
timestamp : datetime
|
||||
as_csv_data() -> Tuple[List, List]
|
||||
}
|
||||
primaite.agents.hardcoded_abc.HardCodedAgentSessionABC --|> primaite.agents.agent_abc.AgentSessionABC
|
||||
primaite.agents.hardcoded_acl.HardCodedACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
|
||||
primaite.agents.hardcoded_node.HardCodedNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
|
||||
primaite.agents.rllib.RLlibAgent --|> primaite.agents.agent_abc.AgentSessionABC
|
||||
primaite.agents.sb3.SB3Agent --|> primaite.agents.agent_abc.AgentSessionABC
|
||||
primaite.agents.simple.DoNothingACLAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
|
||||
primaite.agents.simple.DoNothingNodeAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
|
||||
primaite.agents.simple.DummyAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
|
||||
primaite.agents.simple.RandomAgent --|> primaite.agents.hardcoded_abc.HardCodedAgentSessionABC
|
||||
primaite.environment.observations.AccessControlList_ --|> primaite.environment.observations.AbstractObservationComponent
|
||||
primaite.environment.observations.LinkTrafficLevels --|> primaite.environment.observations.AbstractObservationComponent
|
||||
primaite.environment.observations.NodeLinkTable --|> primaite.environment.observations.AbstractObservationComponent
|
||||
primaite.environment.observations.NodeStatuses --|> primaite.environment.observations.AbstractObservationComponent
|
||||
primaite.nodes.active_node.ActiveNode --|> primaite.nodes.node.Node
|
||||
primaite.nodes.passive_node.PassiveNode --|> primaite.nodes.node.Node
|
||||
primaite.nodes.service_node.ServiceNode --|> primaite.nodes.active_node.ActiveNode
|
||||
primaite.common.service.Service --|> primaite.nodes.service_node.ServiceNode
|
||||
primaite.acl.access_control_list.AccessControlList --* primaite.environment.primaite_env.Primaite : acl
|
||||
primaite.acl.acl_rule.ACLRule --* primaite.acl.access_control_list.AccessControlList : acl_implicit_rule
|
||||
primaite.agents.hardcoded_acl.HardCodedACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.hardcoded_node.HardCodedNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.rllib.RLlibAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.sb3.SB3Agent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.simple.DoNothingACLAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.simple.DoNothingNodeAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.simple.DummyAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.agents.simple.RandomAgent --* primaite.primaite_session.PrimaiteSession : _agent_session
|
||||
primaite.config.training_config.TrainingConfig --* primaite.agents.agent_abc.AgentSessionABC : _training_config
|
||||
primaite.config.training_config.TrainingConfig --* primaite.environment.primaite_env.Primaite : training_config
|
||||
primaite.environment.observations.ObservationsHandler --* primaite.environment.primaite_env.Primaite : obs_handler
|
||||
primaite.environment.primaite_env.Primaite --* primaite.agents.agent_abc.AgentSessionABC : _env
|
||||
primaite.environment.primaite_env.Primaite --* primaite.agents.hardcoded_abc.HardCodedAgentSessionABC : _env
|
||||
primaite.environment.primaite_env.Primaite --* primaite.agents.sb3.SB3Agent : _env
|
||||
primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : episode_av_reward_writer
|
||||
primaite.utils.session_output_writer.SessionOutputWriter --* primaite.environment.primaite_env.Primaite : transaction_writer
|
||||
primaite.config.training_config.TrainingConfig --o primaite.nodes.node.Node : config_values
|
||||
primaite.nodes.node_state_instruction_green.NodeStateInstructionGreen --* primaite.environment.primaite_env.Primaite
|
||||
primaite.nodes.node_state_instruction_red.NodeStateInstructionRed --* primaite.environment.primaite_env.Primaite
|
||||
primaite.pol.ier.IER --* primaite.environment.primaite_env.Primaite
|
||||
primaite.common.protocol.Protocol --o primaite.links.link.Link
|
||||
primaite.links.link.Link --* primaite.environment.primaite_env.Primaite
|
||||
primaite.config.training_config.TrainingConfig --o primaite.nodes.active_node.ActiveNode
|
||||
primaite.utils.session_output_writer.SessionOutputWriter --> primaite.transactions.transaction.Transaction
|
||||
primaite.transactions.transaction.Transaction --> primaite.environment.primaite_env.Primaite
|
||||
@enduml
|
||||
34
docs/Makefile
Normal file
34
docs/Makefile
Normal file
@@ -0,0 +1,34 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
AUTOSUMMARY="source\_autosummary"
|
||||
|
||||
# Remove command is different depending on OS
|
||||
ifdef OS
|
||||
RM = IF exist $(AUTOSUMMARY) ( RMDIR $(AUTOSUMMARY) /s /q )
|
||||
else
|
||||
ifeq ($(shell uname), Linux)
|
||||
RM = rm -rf $(AUTOSUMMARY)
|
||||
endif
|
||||
endif
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
clean:
|
||||
$(RM)
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile | clean
|
||||
pip-licenses --format=rst --with-urls --output-file=source/primaite-dependencies.rst
|
||||
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
0
docs/_static/.gitkeep
vendored
Normal file
0
docs/_static/.gitkeep
vendored
Normal file
41
docs/_templates/custom-class-template.rst
vendored
Normal file
41
docs/_templates/custom-class-template.rst
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
..
|
||||
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
|
||||
..
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:inherited-members:
|
||||
:special-members: __init__, __call__, __add__, __mul__
|
||||
|
||||
{% block methods %}
|
||||
{% if methods %}
|
||||
.. rubric:: {{ _('Methods') }}
|
||||
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
{% for item in methods %}
|
||||
{%- if not item.startswith('_') %}
|
||||
~{{ name }}.{{ item }}
|
||||
{%- endif -%}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block attributes %}
|
||||
{% if attributes %}
|
||||
.. rubric:: {{ _('Attributes') }}
|
||||
|
||||
.. autosummary::
|
||||
{% for item in attributes %}
|
||||
~{{ name }}.{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
73
docs/_templates/custom-module-template.rst
vendored
Normal file
73
docs/_templates/custom-module-template.rst
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
..
|
||||
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
|
||||
..
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. automodule:: {{ fullname }}
|
||||
|
||||
{% block attributes %}
|
||||
{% if attributes %}
|
||||
.. rubric:: Module attributes
|
||||
|
||||
.. autosummary::
|
||||
:toctree:
|
||||
{% for item in attributes %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block functions %}
|
||||
{% if functions %}
|
||||
.. rubric:: {{ _('Functions') }}
|
||||
|
||||
.. autosummary::
|
||||
:toctree:
|
||||
:nosignatures:
|
||||
{% for item in functions %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block classes %}
|
||||
{% if classes %}
|
||||
.. rubric:: {{ _('Classes') }}
|
||||
|
||||
.. autosummary::
|
||||
:toctree:
|
||||
:template: custom-class-template.rst
|
||||
:nosignatures:
|
||||
{% for item in classes %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block exceptions %}
|
||||
{% if exceptions %}
|
||||
.. rubric:: {{ _('Exceptions') }}
|
||||
|
||||
.. autosummary::
|
||||
:toctree:
|
||||
{% for item in exceptions %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block modules %}
|
||||
{% if modules %}
|
||||
.. autosummary::
|
||||
:toctree:
|
||||
:template: custom-module-template.rst
|
||||
:recursive:
|
||||
{% for item in modules %}
|
||||
{{ item }}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
20
docs/api.rst
Normal file
20
docs/api.rst
Normal file
@@ -0,0 +1,20 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
..
|
||||
DO NOT DELETE THIS FILE! It contains the all-important `.. autosummary::` directive with `:recursive:` option, without
|
||||
which API documentation wouldn't get extracted from docstrings by the `sphinx.ext.autosummary` engine. It is hidden
|
||||
(not declared in any toctree) to remove an unnecessary intermediate page; index.rst instead points directly to the
|
||||
package page. DO NOT REMOVE THIS FILE!
|
||||
|
||||
Credit to https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion for the custom templates.
|
||||
..
|
||||
|
||||
.. autosummary::
|
||||
:toctree: source/_autosummary
|
||||
:template: custom-module-template.rst
|
||||
:recursive:
|
||||
|
||||
primaite
|
||||
tests
|
||||
67
docs/build-sphinx-docs-to-github-pages.sh
Normal file
67
docs/build-sphinx-docs-to-github-pages.sh
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
apt-get update
|
||||
apt-get -y install git rsync python3-sphinx
|
||||
|
||||
pwd ls -lah
|
||||
export SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct)
|
||||
|
||||
##############
|
||||
# BUILD DOCS #
|
||||
##############
|
||||
|
||||
cd docs
|
||||
# Python Sphinx, configured with source/conf.py
|
||||
# See https://www.sphinx-doc.org/
|
||||
make clean
|
||||
make html
|
||||
|
||||
cd ..
|
||||
#######################
|
||||
# Update GitHub Pages #
|
||||
#######################
|
||||
|
||||
git config --global user.name "${GITHUB_ACTOR}"
|
||||
git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
|
||||
|
||||
docroot=`mktemp -d`
|
||||
|
||||
rsync -av $PWD/docs/_build/html/ "${docroot}/"
|
||||
|
||||
pushd "${docroot}"
|
||||
|
||||
git init
|
||||
git remote add deploy "https://token:${GITHUB_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
|
||||
git checkout -b sphinx-docs-github-pages
|
||||
|
||||
# Adds .nojekyll file to the root to signal to GitHub that
|
||||
# directories that start with an underscore (_) can remain
|
||||
touch .nojekyll
|
||||
|
||||
# Add README
|
||||
cat > README.md <<EOF
|
||||
# README for the Sphinx Docs GitHub Pages Branch
|
||||
This branch is simply a cache for the website served from https://Autonomous-Resilient-Cyber-Defence.github.io/PrimAITE/,
|
||||
and is not intended to be viewed on github.com.
|
||||
For more information on how this site is built using Sphinx, Read the Docs, GitHub Actions/Pages, and demo
|
||||
implementation from https://github.com/annegentle, see:
|
||||
* https://www.docslikecode.com/articles/github-pages-python-sphinx/
|
||||
* https://tech.michaelaltfield.net/2020/07/18/sphinx-rtd-github-pages-1
|
||||
* https://github.com/annegentle/create-demo
|
||||
EOF
|
||||
|
||||
# Copy the resulting html pages built from Sphinx to the sphinx-docs-github-pages branch
|
||||
git add .
|
||||
|
||||
# Make a commit with changes and any new files
|
||||
msg="Updating Docs for commit ${GITHUB_SHA} made on `date -d"@${SOURCE_DATE_EPOCH}" --iso-8601=seconds` from ${GITHUB_REF} by ${GITHUB_ACTOR}"
|
||||
git commit -am "${msg}"
|
||||
|
||||
# overwrite the contents of the sphinx-docs-github-pages branch on our github.com repo
|
||||
git push deploy sphinx-docs-github-pages --force
|
||||
|
||||
popd # return to main repo sandbox root
|
||||
|
||||
# exit cleanly
|
||||
exit 0
|
||||
57
docs/conf.py
Normal file
57
docs/conf.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
import datetime
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
import os
|
||||
import sys
|
||||
|
||||
import furo # noqa
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../"))
|
||||
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
year = datetime.datetime.now().year
|
||||
project = "PrimAITE"
|
||||
copyright = f"Copyright (C) Defence Science and Technology Laboratory UK 2021 - {year}"
|
||||
author = "Defence Science and Technology Laboratory UK"
|
||||
|
||||
# The short Major.Minor.Build version
|
||||
with open("../src/primaite/VERSION", "r") as file:
|
||||
version = file.readline()
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = version
|
||||
|
||||
html_title = f"{project} v{release} docs"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc", # Core Sphinx library for auto html doc generation from docstrings
|
||||
"sphinx.ext.autosummary", # Create summary tables for modules/classes/methods etc
|
||||
"sphinx.ext.intersphinx", # Link to other project's documentation (see mapping below)
|
||||
"sphinx.ext.viewcode", # Add a link to the Python source code for classes, functions etc.
|
||||
"sphinx.ext.todo",
|
||||
"sphinx_copybutton", # Adds a copy button to code blocks
|
||||
]
|
||||
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "furo"
|
||||
html_static_path = ["_static"]
|
||||
119
docs/index.rst
Normal file
119
docs/index.rst
Normal file
@@ -0,0 +1,119 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
Welcome to PrimAITE's documentation
|
||||
====================================
|
||||
|
||||
What is PrimAITE?
|
||||
-----------------
|
||||
|
||||
Overview
|
||||
^^^^^^^^
|
||||
|
||||
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for the purposes of training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment, which includes:
|
||||
|
||||
- The ability to model a relevant platform / system context;
|
||||
- Modelling an adversarial agent that the defensive agent can be trained and evaluated against;
|
||||
- The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, operating systems, services and traffic loading on links;
|
||||
- Modelling background pattern-of-life;
|
||||
- Operates at machine-speed to enable fast training cycles.
|
||||
|
||||
Features
|
||||
^^^^^^^^
|
||||
|
||||
PrimAITE incorporates the following features:
|
||||
|
||||
- Highly configurable (via YAML files) to provide the means to model a variety of platform / system laydowns, mission profiles and adversarial attack scenarios;
|
||||
- A Reinforcement Learning (RL) reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure mission success;
|
||||
- Provision of logging to support AI performance / effectiveness assessment;
|
||||
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life, adversarial behaviour and mission data (on a sliding scale of criticality);
|
||||
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
|
||||
- Application of traffic to the links of the platform / system laydown adheres to the ACL ruleset;
|
||||
- Presents both an OpenAI gym and Ray RLLib interface to the environment, allowing integration with any compliant defensive agents;
|
||||
- Allows for the saving and loading of trained defensive agents;
|
||||
- Stochastic adversarial agent behaviour;
|
||||
- Full capture of discrete logs relating to agent training or evaluation (system state, agent actions taken, instantaneous and average reward for every step of every episode);
|
||||
- Distinct control over running a training and / or evaluation session;
|
||||
- NetworkX provides laydown visualisation capability.
|
||||
|
||||
Architecture
|
||||
^^^^^^^^^^^^
|
||||
|
||||
PrimAITE is a Python application and is therefore Operating System agnostic. The OpenAI gym and Ray RLLib frameworks are employed to provide an interface and source for AI agents. Configuration of PrimAITE is achieved via included YAML files which support full control over the platform / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count. NetworkX based nodes and links host Python classes to present attributes and methods, and hence a more representative platform / system can be modelled within the simulation.
|
||||
|
||||
|
||||
|
||||
Training & Evaluation Capability
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its OpenAI Gym and RLLib compliant interface. Scenarios can be constructed to reflect platform / system laydowns consisting of any configuration of nodes (e.g. PCs, servers, switches etc.) and network links between them. All nodes can be configured to model services (and their status) and the traffic loading between them over the network links. Traffic loading is broken down into a per service granularity, relating directly to a protocol (e.g. Service A would be configured as a TCP service, and TCP traffic then flows between instances of Service A under the direction of a tailored IER). Highlights of PrimAITE’s training and evaluation capability are:
|
||||
|
||||
- The scenario is not bound to a representation of any platform, system, technology or mission;
|
||||
- Fully configurable (network / system laydown, IERs, node pattern-of-life, ACL, number of episodes, steps per episode) and repeatable to suit the requirements of AI agents;
|
||||
- Can integrate with any OpenAI Gym or RLLib compliant AI agent.
|
||||
|
||||
Use of PrimAITE default scenarios within ARCD is supported by a “Use Case Profile” tailored to the scenario.
|
||||
|
||||
AI Assessment Capability
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
PrimAITE includes the capability to support in-depth assessment of cyber defence AI by outputting logs of the environment state and AI behaviour throughout both training and evaluation sessions. These logs include the following data:
|
||||
|
||||
- Timestamp;
|
||||
- Episode and step number;
|
||||
- Agent identifier;
|
||||
- Observation space;
|
||||
- Action taken (by defensive AI);
|
||||
- Reward value.
|
||||
|
||||
Logs are available in CSV format and provide coverage of the above data for every step of every episode.
|
||||
|
||||
|
||||
|
||||
|
||||
What is PrimAITE built with
|
||||
---------------------------
|
||||
|
||||
* `OpenAI's Gym <https://gym.openai.com/>`_ is used as the basis for AI blue agent interaction with the PrimAITE environment
|
||||
* `Networkx <https://github.com/networkx/networkx>`_ is used as the underlying data structure used for the PrimAITE environment
|
||||
* `Stable Baselines 3 <https://github.com/DLR-RM/stable-baselines3>`_ is used as a default source of RL algorithms (although PrimAITE is not limited to SB3 agents)
|
||||
* `Ray RLlib <https://github.com/ray-project/ray>`_ is used as an additional source of RL algorithms
|
||||
* `Typer <https://github.com/tiangolo/typer>`_ is used for building CLIs (Command Line Interface applications)
|
||||
* `Jupyterlab <https://github.com/jupyterlab/jupyterlab>`_ is used as an extensible environment for interactive and reproducible computing, based on the Jupyter Notebook Architecture
|
||||
* `Platformdirs <https://github.com/platformdirs/platformdirs>`_ is used for finding the right location to store user data and configuration but varies per platform
|
||||
* `Plotly <https://github.com/plotly/plotly.py>`_ is used for building high level charts
|
||||
|
||||
|
||||
Getting Started with PrimAITE
|
||||
-----------------------------
|
||||
|
||||
Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 8
|
||||
:caption: Contents:
|
||||
:hidden:
|
||||
|
||||
source/getting_started
|
||||
source/about
|
||||
source/config
|
||||
source/primaite_session
|
||||
source/custom_agent
|
||||
PrimAITE API <source/_autosummary/primaite>
|
||||
PrimAITE Tests <source/_autosummary/tests>
|
||||
source/dependencies
|
||||
source/glossary
|
||||
source/migration_1.2_-_2.0
|
||||
|
||||
|
||||
.. TODO: Add project links once public repo has been created
|
||||
|
||||
.. toctree::
|
||||
:caption: Project Links:
|
||||
:hidden:
|
||||
|
||||
Code <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE>
|
||||
Issues <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues>
|
||||
Pull Requests <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/pulls>
|
||||
Discussions <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/discussions>
|
||||
58
docs/make.bat
Normal file
58
docs/make.bat
Normal file
@@ -0,0 +1,58 @@
|
||||
@ECHO OFF
|
||||
|
||||
setlocal EnableDelayedExpansion
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
set AUTOSUMMARYDIR="%cd%\source\_autosummary\"
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
REM delete autosummary if it exists
|
||||
|
||||
IF EXIST %AUTOSUMMARYDIR% (
|
||||
echo deleting %AUTOSUMMARYDIR%
|
||||
RMDIR %AUTOSUMMARYDIR% /s /q
|
||||
)
|
||||
|
||||
REM print the YT licenses
|
||||
set LICENSEBUILD=pip-licenses --format=rst --with-urls
|
||||
set DEPS="%cd%\source\primaite-dependencies.rst"
|
||||
|
||||
%LICENSEBUILD% --output-file=%DEPS%
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:clean
|
||||
IF EXIST %AUTOSUMMARYDIR% (
|
||||
echo deleting %AUTOSUMMARYDIR%
|
||||
RMDIR %AUTOSUMMARYDIR% /s /q
|
||||
)
|
||||
|
||||
:end
|
||||
popd
|
||||
414
docs/source/about.rst
Normal file
414
docs/source/about.rst
Normal file
@@ -0,0 +1,414 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
About PrimAITE
|
||||
==============
|
||||
|
||||
Features
|
||||
********
|
||||
|
||||
PrimAITE provides the following features:
|
||||
|
||||
* A flexible network / system laydown based on the Python networkx framework
|
||||
* Nodes and links (edges) host Python classes in order to present attributes and methods (and hence, a more representative model of a platform / system)
|
||||
* A 'green agent' Information Exchange Requirement (IER) function allows the representation of traffic (protocols and loading) on any / all links. Application of IERs is based on the status of node operating systems and services
|
||||
* A 'green agent' node Pattern-of-Life (PoL) function allows the representation of core behaviours on nodes (e.g. changing the Hardware state, Software State, Service state, or File System state)
|
||||
* An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP, destination IP, protocol and port). Application of IERs adheres to any ACL restrictions
|
||||
* Presents an OpenAI Gym interface to the environment, allowing integration with any OpenAI Gym compliant defensive agents
|
||||
* Red agent activity based on 'red' IERs and 'red' PoL
|
||||
* Defined reward function for use with RL agents (based on nodes status, and green / red IER success)
|
||||
* Fully configurable (network / system laydown, IERs, node PoL, ACL, episode step period, episode max steps) and repeatable to suit the training requirements of agents. Therefore, not bound to a representation of any particular platform, system or technology
|
||||
* Full capture of discrete metrics relating to agent training (full system state, agent actions taken, average reward)
|
||||
* Networkx provides laydown visualisation capability
|
||||
|
||||
Architecture - Nodes and Links
|
||||
******************************
|
||||
|
||||
**Nodes**
|
||||
|
||||
An inheritance model has been adopted in order to model nodes. All nodes have the following base attributes (Class: Node):
|
||||
|
||||
* ID
|
||||
* Name
|
||||
* Type (e.g. computer, switch, RTU - enumeration)
|
||||
* Priority (P1, P2, P3, P4 or P5 - enumeration)
|
||||
* Hardware State (ON, OFF, RESETTING, SHUTTING_DOWN, BOOTING - enumeration)
|
||||
|
||||
Active Nodes also have the following attributes (Class: Active Node):
|
||||
|
||||
* IP Address
|
||||
* Software State (GOOD, PATCHING, COMPROMISED - enumeration)
|
||||
* File System State (GOOD, CORRUPT, DESTROYED, REPAIRING, RESTORING - enumeration)
|
||||
|
||||
Service Nodes also have the following attributes (Class: Service Node):
|
||||
|
||||
* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type)
|
||||
* Service state (GOOD, PATCHING, COMPROMISED, OVERWHELMED - enumeration)
|
||||
|
||||
Passive Nodes are currently not used (but may be employed for non IP-based components such as machinery actuators in future releases).
|
||||
|
||||
**Links**
|
||||
|
||||
Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality. Links include the following attributes:
|
||||
|
||||
* ID
|
||||
* Name
|
||||
* Bandwidth (bits/s)
|
||||
* Source node ID
|
||||
* Destination node ID
|
||||
* Protocol list (containing the loading of protocols currently running on the link)
|
||||
|
||||
When the simulation runs, IERs are applied to the links in order to model traffic loading, individually assigned to each protocol. This allows green (background) and red agent behaviour to be modelled, and defensive agents to identify suspicious traffic patterns at a protocol / traffic loading level of fidelity.
|
||||
|
||||
Information Exchange Requirements (IERs)
|
||||
****************************************
|
||||
|
||||
PrimAITE adopts the concept of Information Exchange Requirements (IERs) to model both green agent (background) and red agent (adversary) behaviour. IERs are used to initiate modelling of traffic loading on the network, and have the following attributes:
|
||||
|
||||
* ID
|
||||
* Start step (i.e. which step in the training episode should the IER start)
|
||||
* End step (i.e. which step in the training episode should the IER end)
|
||||
* Source node ID
|
||||
* Destination node ID
|
||||
* Load (bits/s)
|
||||
* Protocol
|
||||
* Port
|
||||
* Running status (i.e. on / off)
|
||||
|
||||
The application of green agent IERs between a source and destination follows a number of rules. Specifically:
|
||||
|
||||
1. Does the current simulation time step fall between IER start and end step
|
||||
2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
|
||||
3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
|
||||
4. Are there any Access Control List rules in place that prevent the application of this IER
|
||||
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
|
||||
|
||||
For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically:
|
||||
|
||||
1. Does the current simulation time step fall between IER start and end step
|
||||
2. Is the source node operational, and is the service (protocol / port) associated with the IER (a) present on that node and (b) already in a compromised state
|
||||
3. Is the destination node operational, and is the service (protocol / port) associated with the IER present on that node
|
||||
4. Are there any Access Control List rules in place that prevent the application of this IER
|
||||
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
|
||||
|
||||
Assuming the rules pass, the IER is applied to all relevant links (based on use of OSPF) between source and destination.
|
||||
|
||||
Node Pattern-of-Life
|
||||
********************
|
||||
|
||||
Every node can be impacted (i.e. have a status change applied to it) by either green agent pattern-of-life or red agent pattern-of-life. This is distinct from IERs, and allows for attacks (and defence) to be modelled purely within the confines of a node.
|
||||
|
||||
The status changes that can be made to a node are as follows:
|
||||
|
||||
* All Nodes:
|
||||
|
||||
* Hardware State:
|
||||
|
||||
* ON
|
||||
* OFF
|
||||
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
|
||||
* BOOTING
|
||||
* SHUTTING_DOWN
|
||||
|
||||
* Active Nodes and Service Nodes:
|
||||
|
||||
* Software State:
|
||||
|
||||
* GOOD
|
||||
* PATCHING - when a status of patching is entered, the node will automatically exit this state after a number of steps (as defined by the osPatchingDuration configuration item) after which it returns to a GOOD state
|
||||
* COMPROMISED
|
||||
|
||||
* File System State:
|
||||
|
||||
* GOOD
|
||||
* CORRUPT (can be resolved by repair or restore)
|
||||
* DESTROYED (can be resolved by restore only)
|
||||
* REPAIRING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRepairingLimit configuration item) after which it returns to a GOOD state
|
||||
* RESTORING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRestoringLimit configuration item) after which it returns to a GOOD state
|
||||
|
||||
* Service Nodes only:
|
||||
|
||||
* Service State (for any associated service):
|
||||
|
||||
* GOOD
|
||||
* PATCHING - when a status of patching is entered, the service will automatically exit this state after a number of steps (as defined by the servicePatchingDuration configuration item) after which it returns to a GOOD state
|
||||
* COMPROMISED
|
||||
* OVERWHELMED
|
||||
|
||||
Red agent pattern-of-life has an additional feature not found in the green pattern-of-life. This is the ability to influence the state of the attributes of a node via a number of different conditions:
|
||||
|
||||
* DIRECT:
|
||||
|
||||
The pattern-of-life described by the configuration file item will be applied regardless of any other conditions in the network. This is particularly useful for direct red agent entry into the network.
|
||||
|
||||
* IER:
|
||||
|
||||
The pattern-of-life described by the configuration file item will be applied to the service on the node, only if there is an IER of the same protocol / service type incoming at the specified timestep.
|
||||
|
||||
* SERVICE:
|
||||
|
||||
The pattern-of-life described by the configuration file item will be applied to the node based on the state of a service. The service can either be on the same node, or a different node within the network.
|
||||
|
||||
Access Control List modelling
|
||||
*****************************
|
||||
|
||||
An Access Control List (ACL) is modelled to provide the means to manage traffic flows in the system. This will allow defensive agents the means to turn on / off rules, or potentially create new rules, to counter an attack.
|
||||
|
||||
The ACL follows a standard network firewall format. For example:
|
||||
|
||||
.. list-table:: ACL example
|
||||
:widths: 25 25 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Permission
|
||||
- Source IP
|
||||
- Dest IP
|
||||
- Protocol
|
||||
- Port
|
||||
* - DENY
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
- HTTPS
|
||||
- 443
|
||||
* - ALLOW
|
||||
- 192.168.1.4
|
||||
- ANY
|
||||
- SMTP
|
||||
- 25
|
||||
* - DENY
|
||||
- ANY
|
||||
- 192.168.1.5
|
||||
- ANY
|
||||
- ANY
|
||||
|
||||
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry.
|
||||
|
||||
Observation Spaces
|
||||
******************
|
||||
The observation space provides the blue agent with information about the current status of nodes and links.
|
||||
|
||||
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted.
|
||||
|
||||
NodeLinkTable component
|
||||
-----------------------
|
||||
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below:
|
||||
|
||||
An example observation space is provided below:
|
||||
|
||||
.. list-table:: Observation Space example
|
||||
:widths: 25 25 25 25 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* -
|
||||
- ID
|
||||
- Hardware State
|
||||
- Software State
|
||||
- File System State
|
||||
- Service / Protocol A
|
||||
- Service / Protocol B
|
||||
* - Node A
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
* - Node B
|
||||
- 2
|
||||
- 1
|
||||
- 3
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
* - Node C
|
||||
- 3
|
||||
- 2
|
||||
- 1
|
||||
- 1
|
||||
- 3
|
||||
- 2
|
||||
* - Link 1
|
||||
- 5
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 10000
|
||||
* - Link 2
|
||||
- 6
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 10000
|
||||
* - Link 3
|
||||
- 7
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 5000
|
||||
- 0
|
||||
|
||||
For the nodes, the following values are represented:
|
||||
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
ID
|
||||
Hardware State (1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
|
||||
Operating System State (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||
File System State (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
|
||||
Service1/Protocol1 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||
Service2/Protocol2 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||
]
|
||||
|
||||
(Note that each service available in the network is provided as a column, although not all nodes may utilise all services)
|
||||
|
||||
For the links, the following statuses are represented:
|
||||
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
ID
|
||||
Hardware State (0=not applicable)
|
||||
Operating System State (0=not applicable)
|
||||
File System State (0=not applicable)
|
||||
Service1/Protocol1 state (Traffic load from this protocol on this link)
|
||||
Service2/Protocol2 state (Traffic load from this protocol on this link)
|
||||
]
|
||||
|
||||
NodeStatus component
|
||||
----------------------
|
||||
This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states.
|
||||
The example above would have the following structure:
|
||||
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
node1_info
|
||||
node2_info
|
||||
node3_info
|
||||
]
|
||||
|
||||
Each ``node_info`` contains the following:
|
||||
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
hardware_state (0=none, 1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
|
||||
software_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||
file_system_state (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
|
||||
service1_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||
]
|
||||
|
||||
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4])
|
||||
|
||||
.. note::
|
||||
NodeStatus observation component provides information only about nodes. Links are not considered.
|
||||
|
||||
LinkTrafficLevels
|
||||
-----------------
|
||||
This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value.
|
||||
There are two configurable parameters:
|
||||
* ``quantisation_levels`` determines how many discrete bins to use for converting the continuous traffic value to discrete (default is 5).
|
||||
* ``combine_service_traffic`` determines whether to separately output traffic use for each network protocol or whether to combine them into an overall value for the link. (default is ``True``)
|
||||
|
||||
For example, with default parameters and a network with three links, the structure of this component would be:
|
||||
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
link1_status
|
||||
link2_status
|
||||
link3_status
|
||||
]
|
||||
|
||||
Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth.
|
||||
|
||||
.. code-block::
|
||||
|
||||
0 = No traffic (0%)
|
||||
1 = low traffic (1%-33%)
|
||||
2 = medium traffic (33%-66%)
|
||||
3 = high traffic (66%-99%)
|
||||
4 = max traffic/ overwhelmed (100%)
|
||||
|
||||
Using ``gym`` notation, the shape of the obs space is: ``gym.spaces.MultiDiscrete([5,5,5])``.
|
||||
|
||||
|
||||
Action Spaces
|
||||
**************
|
||||
|
||||
The action space available to the blue agent comes in two types:
|
||||
|
||||
1. Node-based
|
||||
2. Access Control List
|
||||
3. Any (Agent can take both node-based and ACL-based actions)
|
||||
|
||||
The choice of action space used during a training session is determined in the config_[name].yaml file.
|
||||
|
||||
**Node-Based**
|
||||
|
||||
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
|
||||
|
||||
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
|
||||
The placeholders inside the list under the key '1' mean the following:
|
||||
|
||||
* [0, num nodes] - Node ID (0 = nothing, node ID)
|
||||
* [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state)
|
||||
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
|
||||
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
|
||||
|
||||
**Access Control List**
|
||||
|
||||
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows:
|
||||
|
||||
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
|
||||
The placeholders inside the list under the key '1' mean the following:
|
||||
|
||||
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
* [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
|
||||
**ANY**
|
||||
The agent is able to carry out both **Node-Based** and **Access Control List** operations.
|
||||
|
||||
This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above.
|
||||
|
||||
Rewards
|
||||
*******
|
||||
|
||||
A reward value is presented back to the blue agent on the conclusion of every step. The reward value is calculated via two methods which combine to give the total value:
|
||||
|
||||
1. Node and service status
|
||||
2. IER status
|
||||
|
||||
**Node and service status**
|
||||
|
||||
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
|
||||
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
|
||||
difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values.
|
||||
|
||||
**IER status**
|
||||
|
||||
On every step, the full IER set is examined to determine whether green and red agent IERs are being permitted to run. Any red agent IERs running incur a penalty; any green agent
|
||||
IERs not permitted to run also incur a penalty. See :ref:`config` for details of reward values.
|
||||
|
||||
Future Enhancements
|
||||
*******************
|
||||
|
||||
The PrimAITE project has an ambition to include the following enhancements in future releases:
|
||||
|
||||
* Integration with a suitable standardised framework to allow multi-agent integration
|
||||
* Integration with external threat emulation tools, either using off-line data, or integrating at runtime
|
||||
489
docs/source/config.rst
Normal file
489
docs/source/config.rst
Normal file
@@ -0,0 +1,489 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _config:
|
||||
|
||||
The Config Files Explained
|
||||
==========================
|
||||
|
||||
PrimAITE uses two configuration files for its operation:
|
||||
|
||||
* **The Training Config**
|
||||
|
||||
Used to define the top-level settings of the PrimAITE environment, the reward values, and the session that is to be run.
|
||||
|
||||
* **The Lay Down Config**
|
||||
|
||||
Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERSs) and Access Control Rules.
|
||||
|
||||
Training Config:
|
||||
*******************
|
||||
|
||||
The Training Config file consists of the following attributes:
|
||||
|
||||
**Generic Config Values**
|
||||
|
||||
|
||||
* **agent_framework** [enum]
|
||||
|
||||
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
|
||||
|
||||
* NONE - Where a user developed agent is to be used
|
||||
* SB3 - Stable Baselines3
|
||||
* RLLIB - Ray RLlib.
|
||||
|
||||
* **agent_identifier**
|
||||
|
||||
This identifies the agent to use for the session. Select from one of the following:
|
||||
|
||||
* A2C - Advantage Actor Critic
|
||||
* PPO - Proximal Policy Optimization
|
||||
* HARDCODED - A custom built deterministic agent
|
||||
* RANDOM - A Stochastic random agent
|
||||
|
||||
|
||||
* **random_red_agent** [bool]
|
||||
|
||||
Determines if the session should be run with a random red agent
|
||||
|
||||
* **action_type** [enum]
|
||||
|
||||
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
|
||||
|
||||
|
||||
* **OBSERVATION_SPACE** [dict]
|
||||
|
||||
Allows for user to configure observation space by combining one or more observation components. List of available
|
||||
components is in :py:mod:`primaite.environment.observations`.
|
||||
|
||||
The observation space config item should have a ``components`` key which is a list of components. Each component
|
||||
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
|
||||
component while it is being initialised.
|
||||
|
||||
This example illustrates the correct format for the observation space config item
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
combine_service_traffic : False
|
||||
quantisation_levels: 99
|
||||
|
||||
|
||||
Currently available components are:
|
||||
|
||||
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
|
||||
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
|
||||
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
|
||||
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
|
||||
|
||||
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
|
||||
* ``quantisation_levels`` - how many discrete bandwidth usage levels to use for encoding. This can be an integer equal to or greater than 3.
|
||||
|
||||
The other configurable item is ``flatten`` which is false by default. When set to true, the observation space is flattened (turned into a 1-D vector). You should use this if your RL agent does not natively support observation space types like ``gym.Spaces.Tuple``.
|
||||
|
||||
* **num_train_episodes** [int]
|
||||
|
||||
This defines the number of episodes that the agent will train for.
|
||||
|
||||
|
||||
* **num_train_steps** [int]
|
||||
|
||||
Determines the number of steps to run in each episode of the training session.
|
||||
|
||||
|
||||
* **num_eval_episodes** [int]
|
||||
|
||||
This defines the number of episodes that the agent will be evaluated over.
|
||||
|
||||
|
||||
* **num_eval_steps** [int]
|
||||
|
||||
Determines the number of steps to run in each episode of the evaluation session.
|
||||
|
||||
|
||||
* **time_delay** [int]
|
||||
|
||||
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
|
||||
|
||||
|
||||
* **session_type** [text]
|
||||
|
||||
Type of session to be run (TRAINING, EVALUATION, or BOTH)
|
||||
|
||||
* **load_agent** [bool]
|
||||
|
||||
Determine whether to load an agent from file
|
||||
|
||||
* **agent_load_file** [text]
|
||||
|
||||
File path and file name of agent if you're loading one in
|
||||
|
||||
* **observation_space_high_value** [int]
|
||||
|
||||
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
|
||||
|
||||
* **implicit_acl_rule** [str]
|
||||
|
||||
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
|
||||
|
||||
* **max_number_acl_rules** [int]
|
||||
|
||||
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
|
||||
|
||||
**Reward-Based Config Values**
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
|
||||
* **Generic [all_ok]** [float]
|
||||
|
||||
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
|
||||
|
||||
* **Node Hardware State [off_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is off
|
||||
|
||||
* **Node Hardware State [off_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is off
|
||||
|
||||
* **Node Hardware State [on_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is on
|
||||
|
||||
* **Node Hardware State [on_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is on
|
||||
|
||||
* **Node Hardware State [resetting_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting]** [float]
|
||||
|
||||
The score to give when the node is resetting
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is good
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching]** [float]
|
||||
|
||||
The score to give when the state is patching
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised]** [float]
|
||||
|
||||
The score to give when the state is compromised
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed]** [float]
|
||||
|
||||
The score to give when the state is overwhelmed
|
||||
|
||||
* **Node File System State [good_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is good
|
||||
|
||||
* **Node File System State [good_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is good
|
||||
|
||||
* **Node File System State [good_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is good
|
||||
|
||||
* **Node File System State [good_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is good
|
||||
|
||||
* **Node File System State [repairing_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is repairing
|
||||
|
||||
* **Node File System State [repairing]** [float]
|
||||
|
||||
The score to give when the state is repairing
|
||||
|
||||
* **Node File System State [restoring_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is restoring
|
||||
|
||||
* **Node File System State [restoring]** [float]
|
||||
|
||||
The score to give when the state is restoring
|
||||
|
||||
* **Node File System State [corrupt_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt]** [float]
|
||||
|
||||
The score to give when the state is corrupt
|
||||
|
||||
* **Node File System State [destroyed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed]** [float]
|
||||
|
||||
The score to give when the state is destroyed
|
||||
|
||||
* **Node File System State [scanning]** [float]
|
||||
|
||||
The score to give when the state is scanning
|
||||
|
||||
* **IER Status [red_ier_running]** [float]
|
||||
|
||||
The score to give when a red agent IER is permitted to run
|
||||
|
||||
* **IER Status [green_ier_blocked]** [float]
|
||||
|
||||
The score to give when a green agent IER is prevented from running
|
||||
|
||||
**Patching / Reset Durations**
|
||||
|
||||
* **os_patching_duration** [int]
|
||||
|
||||
The number of steps to take when patching an Operating System
|
||||
|
||||
* **node_reset_duration** [int]
|
||||
|
||||
The number of steps to take when resetting a node's hardware state
|
||||
|
||||
* **service_patching_duration** [int]
|
||||
|
||||
The number of steps to take when patching a service
|
||||
|
||||
* **file_system_repairing_limit** [int]:
|
||||
|
||||
The number of steps to take when repairing the file system
|
||||
|
||||
* **file_system_restoring_limit** [int]
|
||||
|
||||
The number of steps to take when restoring the file system
|
||||
|
||||
* **file_system_scanning_limit** [int]
|
||||
|
||||
The number of steps to take when scanning the file system
|
||||
|
||||
* **deterministic** [bool]
|
||||
|
||||
Set to true if the agent evaluation should be deterministic. Default is ``False``
|
||||
|
||||
* **seed** [int]
|
||||
|
||||
Seed used in the randomisation in agent training. Default is ``None``
|
||||
|
||||
The Lay Down Config
|
||||
*******************
|
||||
|
||||
The lay down config file consists of the following attributes:
|
||||
|
||||
|
||||
* **itemType: STEPS** [int]
|
||||
|
||||
* **item_type: PORTS** [int]
|
||||
|
||||
Provides a list of ports modelled in this session
|
||||
|
||||
* **item_type: SERVICES** [freetext]
|
||||
|
||||
Provides a list of services modelled in this session
|
||||
|
||||
* **item_type: NODE**
|
||||
|
||||
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **name** [freetext]: Human-readable name of the component
|
||||
* **node_class** [enum]: Relates to the base type of the node. Can be SERVICE, ACTIVE or PASSIVE. PASSIVE nodes do not have an operating system or services. ACTIVE nodes have an operating system, but no services. SERVICE nodes have both an operating system and one or more services
|
||||
* **node_type** [enum]: Relates to the component type. Can be one of CCTV, SWITCH, COMPUTER, LINK, MONITOR, PRINTER, LOP, RTU, ACTUATOR or SERVER
|
||||
* **priority** [enum]: Provides a priority for each node. Can be one of P1, P2, P3, P4 or P5 (which P1 being the highest)
|
||||
* **hardware_state** [enum]: The initial hardware state of the node. Can be one of ON, OFF or RESETTING
|
||||
* **ip_address** [IP address]: The IP address of the component in format xxx.xxx.xxx.xxx
|
||||
* **software_state** [enum]: The intial state of the node operating system. Can be GOOD, PATCHING or COMPROMISED
|
||||
* **file_system_state** [enum]: The initial state of the node file system. Can be GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING
|
||||
* **services**: For each service associated with the node:
|
||||
|
||||
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
|
||||
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
|
||||
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **item_type: LINK**
|
||||
|
||||
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **name** [freetext]: Human-readable name of the component
|
||||
* **bandwidth** [int]: The bandwidth (in bits/s) of the link
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
|
||||
* **item_type: GREEN_IER**
|
||||
|
||||
Defines a green agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this IER to begin
|
||||
* **end_step** [int]: The end step (in the episode) for this IER to finish
|
||||
* **load** [int]: The load (in bits/s) for this IER to apply to links
|
||||
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
|
||||
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
* **mission_criticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
|
||||
|
||||
* **item_type: RED_IER**
|
||||
|
||||
Defines a red agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this IER to begin
|
||||
* **end_step** [int]: The end step (in the episode) for this IER to finish
|
||||
* **load** [int]: The load (in bits/s) for this IER to apply to links
|
||||
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
|
||||
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
|
||||
* **source** [int]: The ID of the source node
|
||||
* **destination** [int]: The ID of the destination node
|
||||
* **mission_criticality** [enum]: Not currently used. Default to 0
|
||||
|
||||
* **item_type: GREEN_POL**
|
||||
|
||||
Defines a green agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this PoL to begin
|
||||
* **end_step** [int]: Not currently used. Default to same as start step
|
||||
* **nodeId** [int]: The ID of the node to apply the PoL to
|
||||
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
|
||||
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
|
||||
* **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
|
||||
|
||||
* **item_type: RED_POL**
|
||||
|
||||
Defines a red agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **start_step** [int]: The start step (in the episode) for this PoL to begin
|
||||
* **end_step** [int]: Not currently used. Default to same as start step
|
||||
* **targetNodeId** [int]: The ID of the node to apply the PoL to
|
||||
* **initiator** [enum]: What initiates the PoL. Can be DIRECT, IER or SERVICE
|
||||
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
|
||||
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
|
||||
* **state** [enum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state) or GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING (for file system state)
|
||||
* **sourceNodeId** [int] The ID of the source node containing the service to check (used for SERVICE initiator)
|
||||
* **sourceNodeService** [freetext]: The service on the source node to check (used for SERVICE initiator). Must match a value in the services list for this node
|
||||
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **item_type: ACL_RULE**
|
||||
|
||||
Defines an initial Access Control List (ACL) rule. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
* **permission** [enum]: Defines either an allow or deny rule. Value must be either DENY or ALLOW
|
||||
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
|
||||
142
docs/source/custom_agent.rst
Normal file
142
docs/source/custom_agent.rst
Normal file
@@ -0,0 +1,142 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
Custom Agents
|
||||
=============
|
||||
|
||||
|
||||
Integrating a user defined blue agent
|
||||
*************************************
|
||||
|
||||
.. note::
|
||||
|
||||
If you are planning to implement custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a python package from wheel, custom agents are not supported.
|
||||
|
||||
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
|
||||
|
||||
Below is a barebones example of a custom agent implementation:
|
||||
|
||||
.. code:: python
|
||||
|
||||
# src/primaite/agents/my_custom_agent.py
|
||||
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
|
||||
class CustomAgent(AgentSessionABC):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
assert self._training_config.agent_framework == AgentFramework.CUSTOM
|
||||
assert self._training_config.agent_identifier == AgentIdentifier.MY_AGENT
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
self._agent = ... # your code to setup agent
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_num = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
save_checkpoint = False
|
||||
if checkpoint_num:
|
||||
save_checkpoint = episode_count % checkpoint_num == 0
|
||||
# saves checkpoint if the episode count is not 0 and save_checkpoint flag was set to true
|
||||
if episode_count and save_checkpoint:
|
||||
...
|
||||
# your code to save checkpoint goes here.
|
||||
# The path should start with self.checkpoints_path and include the episode number.
|
||||
|
||||
def learn(self):
|
||||
...
|
||||
# call your agent's learning function here.
|
||||
|
||||
super().learn() # this will finalise learning and output session metadata
|
||||
self.save()
|
||||
|
||||
def evaluate(self):
|
||||
...
|
||||
# call your agent's evaluation function here.
|
||||
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
...
|
||||
# Load an agent from file.
|
||||
|
||||
@classmethod
|
||||
def load(cls, path):
|
||||
...
|
||||
# Create a CustomAgent object which loads model weights from file.
|
||||
|
||||
def save(self):
|
||||
...
|
||||
# Call your agent's function that saves it to a file
|
||||
|
||||
|
||||
You will also need to modify :py:class:`primaite.primaite_session.PrimaiteSession<PrimaiteSession>` and :py:mod:`primaite.common.enums` to capture your new agent identifiers.
|
||||
|
||||
.. code-block:: python
|
||||
:emphasize-lines: 17, 18
|
||||
|
||||
# src/primaite/common/enums.py
|
||||
|
||||
class AgentIdentifier(Enum):
|
||||
"""The Red Agent algo/class."""
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
"Proximal Policy Optimization"
|
||||
HARDCODED = 3
|
||||
"The Hardcoded agents"
|
||||
DO_NOTHING = 4
|
||||
"The DoNothing agents"
|
||||
RANDOM = 5
|
||||
"The RandomAgent"
|
||||
DUMMY = 6
|
||||
"The DummyAgent"
|
||||
CUSTOM_AGENT = 7
|
||||
"Your custom agent"
|
||||
|
||||
.. code-block:: python
|
||||
:emphasize-lines: 3, 11, 12
|
||||
|
||||
# src/primaite_session.py
|
||||
|
||||
from primaite.agents.my_custom_agent import CustomAgent
|
||||
|
||||
# ...
|
||||
|
||||
def setup(self):
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
if self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT:
|
||||
self._agent_session = CustomAgent(self._training_config_path, self._lay_down_config_path)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
Finally, specify your agent in your training config.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
# ~/primaite/2.0.0/config/path/to/your/config_main.yaml
|
||||
|
||||
# Training Config File
|
||||
|
||||
agent_framework: CUSTOM
|
||||
agent_identifier: CUSTOM_AGENT
|
||||
random_red_agent: False
|
||||
# ...
|
||||
|
||||
Now you can :ref:`run a primaite session<run a primaite session>` with your custom agent by passing in the custom ``config_main``.
|
||||
14
docs/source/dependencies.rst
Normal file
14
docs/source/dependencies.rst
Normal file
@@ -0,0 +1,14 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. role:: raw-html(raw)
|
||||
:format: html
|
||||
|
||||
Dependencies
|
||||
============
|
||||
|
||||
PrimAITE Dependencies
|
||||
---------------------
|
||||
|
||||
.. include:: primaite-dependencies.rst
|
||||
149
docs/source/getting_started.rst
Normal file
149
docs/source/getting_started.rst
Normal file
@@ -0,0 +1,149 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _getting-started:
|
||||
|
||||
Getting Started
|
||||
===============
|
||||
|
||||
**Getting Started with PrimAITE**
|
||||
|
||||
Pre-Requisites
|
||||
|
||||
In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.10 installed. If you don't already have it, this is how to install it:
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
sudo add-apt-repository ppa:deadsnakes/ppa
|
||||
sudo apt install python3.10
|
||||
sudo apt-get install python3-pip
|
||||
sudo apt-get install python3-venv
|
||||
|
||||
|
||||
.. code-block:: text
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
- Manual install from: https://www.python.org/downloads/release/python-31011/
|
||||
|
||||
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
|
||||
|
||||
Install PrimAITE
|
||||
****************
|
||||
|
||||
1. Create a primaite directory in your home directory:
|
||||
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
mkdir ~/primaite/2.0.0
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
mkdir ~\primaite\2.0.0
|
||||
|
||||
2. Navigate to the primaite directory and create a new python virtual environment (venv)
|
||||
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
cd ~/primaite/2.0.0
|
||||
python3 -m venv .venv
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
cd ~\primaite\2.0.0
|
||||
python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
|
||||
3. Activate the venv
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
source .venv/bin/activate
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
.\.venv\Scripts\activate
|
||||
|
||||
|
||||
4. Install PrimAITE using pip from PyPi
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install primaite
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
pip install primaite
|
||||
|
||||
5. Perform the PrimAITE setup
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
primaite setup
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
primaite setup
|
||||
|
||||
Clone & Install PrimAITE for Development
|
||||
****************************************
|
||||
|
||||
To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
|
||||
of your choice:
|
||||
|
||||
.. TODO:: Add repo path once we know what it is
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone <repo path>
|
||||
cd primaite
|
||||
|
||||
Create and activate your Python virtual environment (venv)
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
python3 -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
Install PrimAITE with the dev extra
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install -e .[dev]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).
|
||||
81
docs/source/glossary.rst
Normal file
81
docs/source/glossary.rst
Normal file
@@ -0,0 +1,81 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
Glossary
|
||||
=============
|
||||
|
||||
.. glossary::
|
||||
:sorted:
|
||||
|
||||
Network
|
||||
The network in primaite is a logical representation of a computer network containing :term:`Nodes<Node>` and :term:`Links<Link>`.
|
||||
|
||||
Node
|
||||
A Node represents a network endpoint. For example a computer, server, switch, or an actuator.
|
||||
|
||||
Link
|
||||
A Link represents the connection between two Nodes. For example, a physical wire between a computer and a switch or a wireless connection.
|
||||
|
||||
Protocol
|
||||
Protocols are used by links to separate different types of network traffic. Common examples would be HTTP, TCP, and UDP.
|
||||
|
||||
Service
|
||||
A service represents a piece of software that is installed on a node, such as a web server or a database.
|
||||
|
||||
Access Control List
|
||||
PrimAITE blocks or allows certain traffic on the network by simulating firewall rules, which are defined in the Access Control List.
|
||||
|
||||
Agent
|
||||
An agent is a representation of a user of the network. Typically this would be a user that is using one of the computer nodes, though it could be an autonomous agent.
|
||||
|
||||
Green agent
|
||||
Simulates typical benign activity on the network, such as real users using computers and servers.
|
||||
|
||||
Red Agent
|
||||
An agent that is aiming to attack the network in some way, for example by executing a Denial-Of-Service attack or stealing data.
|
||||
|
||||
Blue Agent
|
||||
A defensive agent that protects the network from Red Agent attacks to minimise disruption to green agents and protect data.
|
||||
|
||||
Information Exchange Requirement (IER)
|
||||
Simulates network traffic by sending data from one network node to another via links for a specified amount of time. IERs can be part of green agent behaviour or red agent behaviour. PrimAITE can be configured to apply a penalty for green agents' IERs being blocked and a reward for red agents' IERs being blocked.
|
||||
|
||||
Pattern-of-Life (PoL)
|
||||
PoLs allow agents to change the current hardware, OS, file system, or service statuses of nodes during the course of an episode. For example, a green agent may restart a server node to represent scheduled maintainance. A red agent's Pattern-of-Life can be used to attack nodes by changing their states to CORRUPTED or COMPROMISED.
|
||||
|
||||
Reward
|
||||
The reward is a single number used by the blue agent to understand whether it's performing well or poorly. RL agents change their behaviour in an attempt to increase the expected reward each episode. The reward is generated based on the current states of the environment / :term:`reference environment` and is impacted positively by things like green IERS running successfully and negatively by things like nodes being compromised.
|
||||
|
||||
Observation
|
||||
An observation is a representation of the current state of the environment that is given to the learning agent so it can decide on which action to perform. If the environment is 'fully observable', the observation contains information about every possible aspect of the environment. More commonly, the environment is 'partially observable' which means the learning agent has to make decisions without knowing every detail of the current environment state.
|
||||
|
||||
Action
|
||||
The learning agent decides on an action to take on every step in the simulation. The action has the chance to positively or negatively impact the environment state. Over time, the agent aims to learn which actions to take when to maximise the expected reward.
|
||||
|
||||
Training
|
||||
During training, an RL agent is placed in the simulated network and it learns which actions to take in which scenarios to obtain maximum reward.
|
||||
|
||||
Evaluation
|
||||
During evaluation, an RL agent acts on the simulated network but it is not allowed to update it's behaviour. Evaluation is used to assess how successful agents are at defending the network.
|
||||
|
||||
Step
|
||||
The agents can only act in the environment at discrete intervals. The time step is the basic unit of time in the simulation. At each step, the RL agent has an opportunity to observe the state of the environment and decide an action. Steps are also used for updating states for time-dependent activities such as rebooting a node.
|
||||
|
||||
Episode
|
||||
When an episode starts, the network simulation is reset to an initial state. The agents take actions on each step of the episode until it reaches a terminal state, which usually happens after a predetermined number of steps. After the terminal state is reached, a new episode starts and the RL agent has another opportunity to protect the network.
|
||||
|
||||
Reference environment
|
||||
While the network simulation is unfolding, a parallel simulation takes place which is identical to the main one except that blue and red agent actions are not applied. This reference environment essentially shows what would be happening to the network if there had been no cyberattack or defense. The reference environment is used to calculate rewards.
|
||||
|
||||
Transaction
|
||||
PrimAITE records the decisions of the learning agent by saving its observation, action, and reward at every time step. During each session, this data is saved to disk to allow for full inspection.
|
||||
|
||||
Laydown
|
||||
The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours.
|
||||
|
||||
Gym
|
||||
PrimAITE uses the Gym reinforcement learning framework API to create a training environment and interface with RL agents. Gym defines a common way of creating observations, actions, and rewards.
|
||||
|
||||
User app home
|
||||
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>` on linux/darwin and `C:\Users\<username>\primaite\<version>` on Windows.
|
||||
57
docs/source/migration_1.2_-_2.0.rst
Normal file
57
docs/source/migration_1.2_-_2.0.rst
Normal file
@@ -0,0 +1,57 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
v1.2 to v2.0 Migration guide
|
||||
============================
|
||||
|
||||
**1. Installing PrimAITE**
|
||||
|
||||
Like before, you can install primaite from the repository by running ``pip install -e .``. But, there is now an additional setup step which does several things, like setting up user directories, copy default configs and notebooks, etc. Once you have installed PrimAITE to your virtual environment, run this command to finalise setup.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
primaite setup
|
||||
|
||||
**2. Running a training session**
|
||||
|
||||
In version 1.2 of PrimAITE, the main entry point for training or evaluating agents was the ``src/primaite/main.py`` file. v2.0.0 introduced managed 'sessions' which are responsible for reading configuration files, performing training, and writing outputs.
|
||||
|
||||
``main.py`` file still runs a training session but it now uses the new `PrimaiteSession`, and it now requires you to provide the path to your config files.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python src/primaite/main.py --tc path/to/training-config.yaml --ldc path/to/laydown-config.yaml
|
||||
|
||||
Alternatively, the session can be invoked via the commandline by running:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
primaite session --tc path/to/training-config.yaml --ldc path/to/laydown-config.yaml
|
||||
|
||||
**3. Location of configs**
|
||||
|
||||
In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, and now when you install and setup PrimAITE, config files are stored in your user data location. On Linux/OSX, this is stored in ``~/primaite/2.0.0/config``. On Windows, this is stored in ``C:\Users\<your username>\primaite\configs``. Upon first setup, the configs folder is populated with some default yaml files. It is recommended that you store all your custom configuration files here.
|
||||
|
||||
**4. Contents of configs**
|
||||
|
||||
Some things that were previously part of the laydown config are now part of the traning config.
|
||||
|
||||
* Actions
|
||||
|
||||
If you have custom configs which use these, you will need to adapt them by moving the configuration from the laydown config to the training config.
|
||||
|
||||
Also, there are new configurable items in the training config:
|
||||
|
||||
* Observations
|
||||
* Agent framework
|
||||
* Agent
|
||||
* Deep learning framework
|
||||
* random red agents
|
||||
* seed
|
||||
* deterministic
|
||||
* hard coded agent view
|
||||
|
||||
Each of these items have default values which are designed so that PrimAITE has the same behaviour as it did in 1.2.0, so you do not have to specify them.
|
||||
|
||||
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.
|
||||
212
docs/source/primaite_session.rst
Normal file
212
docs/source/primaite_session.rst
Normal file
@@ -0,0 +1,212 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _run a primaite session:
|
||||
|
||||
Run a PrimAITE Session
|
||||
======================
|
||||
|
||||
Run
|
||||
---
|
||||
|
||||
A PrimAITE session can be ran either with the ``primaite session`` command from the cli
|
||||
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
|
||||
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
|
||||
|
||||
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix CLI
|
||||
|
||||
cd ~/primaite/2.0.0
|
||||
source ./.venv/bin/activate
|
||||
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Powershell CLI
|
||||
|
||||
cd ~\primaite\2.0.0
|
||||
.\.venv\Scripts\activate
|
||||
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Python
|
||||
|
||||
from primaite.main import run
|
||||
|
||||
training_config = <path to training config yaml file>
|
||||
lay_down_config = <path to lay down config yaml file>
|
||||
run(training_config, lay_down_config)
|
||||
|
||||
When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0/sessions``).
|
||||
The sub-directory is formatted as such: ``~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-dd>/``
|
||||
|
||||
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
|
||||
``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
||||
|
||||
``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``.
|
||||
|
||||
To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``legacy-ldc`` options.
|
||||
|
||||
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix CLI
|
||||
|
||||
cd ~/primaite/2.0.0
|
||||
source ./.venv/bin/activate
|
||||
primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Powershell CLI
|
||||
|
||||
cd ~\primaite\2.0.0
|
||||
.\.venv\Scripts\activate
|
||||
primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Python
|
||||
|
||||
from primaite.main import run
|
||||
|
||||
training_config = <path to legacy training config yaml file>
|
||||
lay_down_config = <path to legacy lay down config yaml file>
|
||||
run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True)
|
||||
|
||||
|
||||
|
||||
|
||||
Outputs
|
||||
-------
|
||||
|
||||
PrimAITE produces four types of outputs:
|
||||
|
||||
* Session Metadata
|
||||
* Results
|
||||
* Diagrams
|
||||
* Saved agents (training checkpoints and a final trained agent)
|
||||
|
||||
|
||||
**Session Metadata**
|
||||
|
||||
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
|
||||
|
||||
* **uuid** - The UUID assigned to the session upon instantiation.
|
||||
* **start_datetime** - The date & time the session started in iso format.
|
||||
* **end_datetime** - The date & time the session ended in iso format.
|
||||
* **learning**
|
||||
* **total_episodes** - The total number of training episodes completed.
|
||||
* **total_time_steps** - The total number of training time steps completed.
|
||||
* **evaluation**
|
||||
* **total_episodes** - The total number of evaluation episodes completed.
|
||||
* **total_time_steps** - The total number of evaluation time steps completed.
|
||||
* **env**
|
||||
* **training_config**
|
||||
* **All training config items**
|
||||
* **lay_down_config**
|
||||
* **All lay down config items**
|
||||
|
||||
|
||||
**Results**
|
||||
|
||||
PrimAITE automatically creates two sets of results from each learning and evaluation session:
|
||||
|
||||
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
|
||||
* All transactions - a csv file listing the following values for every step of every episode:
|
||||
|
||||
* Timestamp
|
||||
* Episode number
|
||||
* Step number
|
||||
* Reward value
|
||||
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
* Initial observation space (what the blue agent observed when it decided its action)
|
||||
|
||||
**Diagrams**
|
||||
|
||||
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
|
||||
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory.
|
||||
|
||||
**Saved agents**
|
||||
|
||||
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
|
||||
|
||||
**Example Session Directory Structure**
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
~/
|
||||
└── primaite/
|
||||
└── 2.0.0/
|
||||
└── sessions/
|
||||
└── 2023-07-18/
|
||||
└── 2023-07-18_11-06-04/
|
||||
├── evaluation/
|
||||
│ ├── all_transactions_2023-07-18_11-06-04.csv
|
||||
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
|
||||
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
|
||||
├── learning/
|
||||
│ ├── all_transactions_2023-07-18_11-06-04.csv
|
||||
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
|
||||
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
|
||||
│ ├── checkpoints/
|
||||
│ │ └── sb3ppo_10.zip
|
||||
│ ├── SB3_PPO.zip
|
||||
│ └── tensorboard_logs/
|
||||
│ ├── PPO_1/
|
||||
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
|
||||
│ ├── PPO_2/
|
||||
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
|
||||
│ ├── PPO_3/
|
||||
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
|
||||
│ ├── PPO_4/
|
||||
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
|
||||
│ ├── PPO_5/
|
||||
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
|
||||
│ ├── PPO_6/
|
||||
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
|
||||
│ ├── PPO_7/
|
||||
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
|
||||
│ ├── PPO_8/
|
||||
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
|
||||
│ ├── PPO_9/
|
||||
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
|
||||
│ └── PPO_10/
|
||||
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
|
||||
├── network_2023-07-18_11-06-04.png
|
||||
└── session_metadata.json
|
||||
|
||||
Loading a session
|
||||
-----------------
|
||||
|
||||
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
|
||||
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
|
||||
|
||||
.. tabs::
|
||||
|
||||
.. code-tab:: bash
|
||||
:caption: Unix CLI
|
||||
|
||||
cd ~/primaite/2.0.0
|
||||
source ./.venv/bin/activate
|
||||
primaite session --load "path/to/session"
|
||||
|
||||
.. code-tab:: bash
|
||||
:caption: Powershell CLI
|
||||
|
||||
cd ~\primaite\2.0.0
|
||||
.\.venv\Scripts\activate
|
||||
primaite session --load "path\to\session"
|
||||
|
||||
|
||||
.. code-tab:: python
|
||||
:caption: Python
|
||||
|
||||
from primaite.main import run
|
||||
|
||||
run(session_path=<previous session directory>)
|
||||
|
||||
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
|
||||
88
pyproject.toml
Normal file
88
pyproject.toml
Normal file
@@ -0,0 +1,88 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "setuptools-scm", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "primaite"
|
||||
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
|
||||
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
|
||||
license = {file = "LICENSE"}
|
||||
requires-python = ">=3.8, <3.11"
|
||||
dynamic = ["version", "readme"]
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Unix",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"gym==0.21.0",
|
||||
"jupyterlab==3.6.1",
|
||||
"kaleido==0.2.1",
|
||||
"matplotlib==3.7.1",
|
||||
"networkx==3.1",
|
||||
"numpy==1.23.5",
|
||||
"platformdirs==3.5.1",
|
||||
"plotly==5.15.0",
|
||||
"polars==0.18.4",
|
||||
"PyYAML==6.0",
|
||||
"ray[rllib]==2.2.0",
|
||||
"stable-baselines3==1.6.2",
|
||||
"tensorflow==2.12.0",
|
||||
"typer[all]==0.9.0"
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {file = ["src/primaite/VERSION"]}
|
||||
readme = {file = ["README.md"]}
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
include-package-data = true
|
||||
license-files = ["LICENSE"]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"build==0.10.0",
|
||||
"flake8==6.0.0",
|
||||
"furo==2023.3.27",
|
||||
"gputil==1.4.0",
|
||||
"pip-licenses==4.3.0",
|
||||
"pre-commit==2.20.0",
|
||||
"pylatex==1.4.1",
|
||||
"pytest==7.2.0",
|
||||
"pytest-xdist==3.3.1",
|
||||
"pytest-cov==4.0.0",
|
||||
"pytest-flake8==1.1.1",
|
||||
"setuptools==66",
|
||||
"Sphinx==6.1.3",
|
||||
"sphinx-copybutton==0.5.2",
|
||||
"wheel==0.38.4"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
primaite = "primaite.cli:app"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 120
|
||||
force_sort_within_sections = "False"
|
||||
order_by_type = "False"
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE"
|
||||
Documentation = "https://Autonomous-Resilient-Cyber-Defence.github.io/PrimAITE/"
|
||||
Repository = "https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE"
|
||||
Changelog = "https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/blob/dev/CHANGELOG.md"
|
||||
5
pytest.ini
Normal file
5
pytest.ini
Normal file
@@ -0,0 +1,5 @@
|
||||
[pytest]
|
||||
testpaths =
|
||||
tests
|
||||
markers =
|
||||
env_config_paths
|
||||
4
setup.cfg
Normal file
4
setup.cfg
Normal file
@@ -0,0 +1,4 @@
|
||||
[metadata]
|
||||
url = https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE
|
||||
author = Defence Science and Technology Laboratory UK
|
||||
author_email = oss@dstl.gov.uk
|
||||
17
setup.py
Normal file
17
setup.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from setuptools import setup
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel # noqa
|
||||
|
||||
|
||||
class bdist_wheel(_bdist_wheel): # noqa
|
||||
def finalize_options(self): # noqa
|
||||
super().finalize_options()
|
||||
# Set to False if you need to build OS and Python specific wheels
|
||||
self.root_is_pure = True # noqa
|
||||
|
||||
|
||||
setup(
|
||||
cmdclass={
|
||||
"bdist_wheel": bdist_wheel,
|
||||
}
|
||||
)
|
||||
1
src/primaite/VERSION
Normal file
1
src/primaite/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
2.0.0
|
||||
207
src/primaite/__init__.py
Normal file
207
src/primaite/__init__.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import logging
|
||||
import logging.config
|
||||
import shutil
|
||||
import sys
|
||||
from bisect import bisect
|
||||
from logging import Formatter, Logger, LogRecord, StreamHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, List
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
from platformdirs import PlatformDirs
|
||||
|
||||
with open(Path(__file__).parent.resolve() / "VERSION", "r") as file:
|
||||
__version__ = file.readline().strip()
|
||||
|
||||
|
||||
class _PrimaitePaths:
|
||||
"""
|
||||
A Primaite paths class that leverages PlatformDirs.
|
||||
|
||||
The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
|
||||
|
||||
def _get_dirs_properties(self) -> List[str]:
|
||||
class_items = self.__class__.__dict__.items()
|
||||
return [k for k, v in class_items if isinstance(v, property)]
|
||||
|
||||
def mkdirs(self):
|
||||
"""
|
||||
Creates all Primaite directories.
|
||||
|
||||
Does this by retrieving all properties in the PrimaiteDirs class and calls each one.
|
||||
"""
|
||||
for p in self._get_dirs_properties():
|
||||
getattr(self, p)
|
||||
|
||||
@property
|
||||
def user_home_path(self) -> Path:
|
||||
"""The PrimAITE user home path."""
|
||||
path = Path.home() / "primaite" / __version__
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_sessions_path(self) -> Path:
|
||||
"""The PrimAITE user sessions path."""
|
||||
path = self.user_home_path / "sessions"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_config_path(self) -> Path:
|
||||
"""The PrimAITE user config path."""
|
||||
path = self.user_home_path / "config"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_notebooks_path(self) -> Path:
|
||||
"""The PrimAITE user notebooks path."""
|
||||
path = self.user_home_path / "notebooks"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_home_path(self) -> Path:
|
||||
"""The PrimAITE app home path."""
|
||||
path = self._dirs.user_data_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_dir_path(self) -> Path:
|
||||
"""The PrimAITE app config directory path."""
|
||||
path = self._dirs.user_config_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_file_path(self) -> Path:
|
||||
"""The PrimAITE app config file path."""
|
||||
return self.app_config_dir_path / "primaite_config.yaml"
|
||||
|
||||
@property
|
||||
def app_log_dir_path(self) -> Path:
|
||||
"""The PrimAITE app log directory path."""
|
||||
if sys.platform == "win32":
|
||||
path = self.app_home_path / "logs"
|
||||
else:
|
||||
path = self._dirs.user_log_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_log_file_path(self) -> Path:
|
||||
"""The PrimAITE app log file path."""
|
||||
return self.app_log_dir_path / "primaite.log"
|
||||
|
||||
def __repr__(self):
|
||||
properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()])
|
||||
return f"{self.__class__.__name__}({properties_str})"
|
||||
|
||||
|
||||
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
|
||||
|
||||
|
||||
def _host_primaite_config():
|
||||
if not PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
|
||||
|
||||
|
||||
_host_primaite_config()
|
||||
|
||||
|
||||
def _get_primaite_config() -> Dict:
|
||||
config_path = PRIMAITE_PATHS.app_config_file_path
|
||||
if not config_path.exists():
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
with open(config_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
log_level_map = {
|
||||
"NOTSET": logging.NOTSET,
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
|
||||
return primaite_config
|
||||
|
||||
|
||||
_PRIMAITE_CONFIG = _get_primaite_config()
|
||||
|
||||
|
||||
class _LevelFormatter(Formatter):
|
||||
"""
|
||||
A custom level-specific formatter.
|
||||
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
|
||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
raise ValueError("Format string must be passed to level-surrogate formatters, " "not this one")
|
||||
|
||||
self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items())
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
"""Overrides ``Formatter.format``."""
|
||||
idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
|
||||
level, formatter = self.formats[idx]
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
|
||||
{
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
|
||||
}
|
||||
)
|
||||
|
||||
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
|
||||
|
||||
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
|
||||
filename=PRIMAITE_PATHS.app_log_file_path,
|
||||
maxBytes=10485760, # 10MB
|
||||
backupCount=9, # Max 100MB of logs
|
||||
encoding="utf8",
|
||||
)
|
||||
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
|
||||
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
|
||||
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_LOGGER.addHandler(_STREAM_HANDLER)
|
||||
_LOGGER.addHandler(_FILE_HANDLER)
|
||||
|
||||
|
||||
def getLogger(name: str) -> Logger: # noqa
|
||||
"""
|
||||
Get a PrimAITE logger.
|
||||
|
||||
:param name: The logger name. Use ``__name__``.
|
||||
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
|
||||
logging config.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_PRIMAITE_CONFIG["log_level"])
|
||||
|
||||
return logger
|
||||
2
src/primaite/acl/__init__.py
Normal file
2
src/primaite/acl/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Access Control List. Models firewall functionality."""
|
||||
198
src/primaite/acl/access_control_list.py
Normal file
198
src/primaite/acl/access_control_list.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
import logging
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
# Implicit rule in ACL list
|
||||
if self.acl_implicit_permission == RulePermissionType.DENY:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
|
||||
else:
|
||||
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
|
||||
|
||||
# Maximum number of ACL Rules in ACL
|
||||
self.max_acl_rules: int = max_acl_rules
|
||||
# A list of ACL Rules
|
||||
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Union[ACLRule, None]]:
|
||||
"""Public access method for private _acl."""
|
||||
return self._acl + [self.acl_implicit_rule]
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
|
||||
:param _rule: The rule object to check
|
||||
:type _rule: ACLRule
|
||||
:param _source_ip_address: Source IP address to compare
|
||||
:type _source_ip_address: str
|
||||
:param _dest_ip_address: Destination IP address to compare
|
||||
:type _dest_ip_address: str
|
||||
:return: True if there is a match, otherwise False.
|
||||
:rtype: bool
|
||||
"""
|
||||
if (
|
||||
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
|
||||
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool:
|
||||
"""
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
Args:
|
||||
_source_ip_address: the source IP address to check
|
||||
_dest_ip_address: the destination IP address to check
|
||||
_protocol: the protocol to check
|
||||
_port: the port to check
|
||||
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule in self.acl:
|
||||
if isinstance(rule, ACLRule):
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == RulePermissionType.DENY:
|
||||
return True
|
||||
elif rule.get_permission() == RulePermissionType.ALLOW:
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(
|
||||
self,
|
||||
_permission: RulePermissionType,
|
||||
_source_ip: str,
|
||||
_dest_ip: str,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_position: str,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
_source_ip: the source IP address
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
|
||||
"""
|
||||
try:
|
||||
position_index = int(_position)
|
||||
except TypeError:
|
||||
_LOGGER.info(f"Position {_position} could not be converted to integer.")
|
||||
return
|
||||
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# Checks position is in correct range
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
try:
|
||||
_LOGGER.info(f"Position {position_index} is valid.")
|
||||
# Check to see Agent will not overwrite current ACL in ACL list
|
||||
if self._acl[position_index] is None:
|
||||
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
|
||||
# Adds rule
|
||||
self._acl[position_index] = new_rule
|
||||
else:
|
||||
# Cannot overwrite it
|
||||
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
|
||||
return
|
||||
except Exception:
|
||||
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
|
||||
else:
|
||||
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
|
||||
|
||||
def remove_rule(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
_source_ip: the source IP address
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
delete_rule_hash = hash(rule_to_delete)
|
||||
|
||||
for index in range(0, len(self._acl)):
|
||||
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
|
||||
self._acl[index] = None
|
||||
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
for i in range(len(self._acl)):
|
||||
self._acl[i] = None
|
||||
|
||||
def get_dictionary_hash(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> int:
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
_source_ip: the source IP address
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
|
||||
Returns:
|
||||
Hash value based on rule parameters.
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
def get_relevant_rules(
|
||||
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all ACL rules that relate to the given arguments.
|
||||
|
||||
:param _source_ip_address: the source IP address to check
|
||||
:param _dest_ip_address: the destination IP address to check
|
||||
:param _protocol: the protocol to check
|
||||
:param _port: the port to check
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules = {}
|
||||
for rule in self.acl:
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
|
||||
):
|
||||
# There's a matching rule.
|
||||
relevant_rules[self._acl.index(rule)] = rule
|
||||
|
||||
return relevant_rules
|
||||
87
src/primaite/acl/acl_rule.py
Normal file
87
src/primaite/acl/acl_rule.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
:param _permission: The permission (ALLOW or DENY)
|
||||
:param _source_ip: The source IP address
|
||||
:param _dest_ip: The destination IP address
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission: RulePermissionType = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
Override the hash function.
|
||||
|
||||
Returns:
|
||||
Returns hash of core parameters.
|
||||
"""
|
||||
return hash(
|
||||
(
|
||||
self.permission,
|
||||
self.source_ip,
|
||||
self.dest_ip,
|
||||
self.protocol,
|
||||
self.port,
|
||||
)
|
||||
)
|
||||
|
||||
def get_permission(self) -> str:
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
|
||||
Returns:
|
||||
Returns permission attribute
|
||||
"""
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self) -> str:
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns source IP address attribute
|
||||
"""
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self) -> str:
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns destination IP address attribute
|
||||
"""
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
|
||||
Returns:
|
||||
Returns protocol attribute
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets the port attribute.
|
||||
|
||||
Returns:
|
||||
Returns port attribute
|
||||
"""
|
||||
return self.port
|
||||
2
src/primaite/agents/__init__.py
Normal file
2
src/primaite/agents/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Common interface between RL agents from different libraries and PrimAITE."""
|
||||
319
src/primaite/agents/agent_abc.py
Normal file
319
src/primaite/agents/agent_abc.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
class AgentSessionABC(ABC):
|
||||
"""
|
||||
An ABC that manages training and/or evaluation of agents in PrimAITE.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
If training configuration and laydown configuration are provided with a session path,
|
||||
the session path will be used.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
# initialise variables
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
|
||||
# convert session to path
|
||||
if session_path is not None:
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
# load session
|
||||
self.load(session_path)
|
||||
else:
|
||||
# set training config path
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(
|
||||
self._training_config_path, legacy_file=legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = str(uuid4())
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def timestamp_str(self) -> str:
|
||||
"""The session timestamp as a string."""
|
||||
return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
@property
|
||||
def learning_path(self) -> Path:
|
||||
"""The learning outputs path."""
|
||||
path = self.session_path / "learning"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def evaluation_path(self) -> Path:
|
||||
"""The evaluation outputs path."""
|
||||
path = self.session_path / "evaluation"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def checkpoints_path(self) -> Path:
|
||||
"""The Session checkpoints path."""
|
||||
path = self.learning_path / "checkpoints"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self) -> str:
|
||||
"""The Agent Session UUID."""
|
||||
return self._uuid
|
||||
|
||||
def _write_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Write the ``session_metadata.json`` file.
|
||||
|
||||
Creates a ``session_metadata.json`` in the ``session_path`` directory
|
||||
and adds the following key/value pairs:
|
||||
|
||||
- uuid: The UUID assigned to the session upon instantiation.
|
||||
- start_datetime: The date & time the session started in iso format.
|
||||
- end_datetime: NULL.
|
||||
- total_episodes: NULL.
|
||||
- total_time_steps: NULL.
|
||||
- env:
|
||||
- training_config:
|
||||
- All training config items
|
||||
- lay_down_config:
|
||||
- All lay down config items
|
||||
|
||||
"""
|
||||
metadata_dict = {
|
||||
"uuid": self.uuid,
|
||||
"start_datetime": self.session_timestamp.isoformat(),
|
||||
"end_datetime": None,
|
||||
"learning": {"total_episodes": None, "total_time_steps": None},
|
||||
"evaluation": {"total_episodes": None, "total_time_steps": None},
|
||||
"env": {
|
||||
"training_config": self._training_config.to_dict(json_serializable=True),
|
||||
"lay_down_config": self._lay_down_config,
|
||||
},
|
||||
}
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished writing session metadata file")
|
||||
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
with the following key/value pairs:
|
||||
|
||||
- end_datetime: The date & time the session ended in iso format.
|
||||
- total_episodes: The total number of training episodes completed.
|
||||
- total_time_steps: The total number of training time steps completed.
|
||||
"""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self) -> None:
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
|
||||
self._write_session_metadata_file()
|
||||
self._can_learn = True
|
||||
self._can_evaluate = False
|
||||
|
||||
@abstractmethod
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_learn:
|
||||
_LOGGER.info("Finished learning")
|
||||
_LOGGER.debug("Writing transactions")
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
self.is_eval = False
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_evaluate:
|
||||
self._update_session_metadata_file()
|
||||
self.is_eval = True
|
||||
self._plot_av_reward_per_episode(learning_session=False)
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def load(self, path: Union[str, Path]) -> None:
|
||||
"""Load an agent from file."""
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = laydown_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._env.episode_av_reward_writer.close() # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
|
||||
if learning_session:
|
||||
title += "(Learning)"
|
||||
path = self.learning_path / csv_file
|
||||
image_path = self.learning_path / image_file
|
||||
else:
|
||||
title += "(Evaluation)"
|
||||
path = self.evaluation_path / csv_file
|
||||
image_path = self.evaluation_path / image_file
|
||||
|
||||
fig = plot_av_reward_per_episode(path, title, subtitle)
|
||||
fig.write_image(image_path)
|
||||
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
118
src/primaite/agents/hardcoded_abc.py
Normal file
118
src/primaite/agents/hardcoded_abc.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path] = None) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
515
src/primaite/agents/hardcoded_acl.py
Normal file
515
src/primaite/agents/hardcoded_acl.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
get_node_of_ip,
|
||||
transform_action_acl_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardCodedAgentView
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
# full view action using observation space, action
|
||||
# history and reward feedback
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, IER]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
:type green_iers: Dict[str, IER]
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
|
||||
:rtype: Dict[str, IER]
|
||||
"""
|
||||
blocked_green_iers = {}
|
||||
|
||||
for green_ier_id, green_ier in green_iers.items():
|
||||
source_node_id = green_ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = green_ier.get_dest_node_id()
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = green_ier.get_protocol() # e.g. 'TCP'
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
source_node_id = ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = ier.get_dest_node_id()
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
.. warning::
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
(No ALLOW rule, therefore blocked).
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
blocked_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "DENY":
|
||||
blocked_rules[rule_key] = rule_value
|
||||
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "ALLOW":
|
||||
allowed_rules[rule_key] = rule_value
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def get_matching_acl_rules(
|
||||
self,
|
||||
source_node_id: str,
|
||||
dest_node_id: str,
|
||||
protocol: str,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
|
||||
services_list: List[str],
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Filter ACL rules to only those which are relevant to the specified nodes.
|
||||
|
||||
:param source_node_id: Source node
|
||||
:type source_node_id: str
|
||||
:param dest_node_id: Destination nodes
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: str
|
||||
:param port: Network port
|
||||
:type port: str
|
||||
:param acl: Access Control list which will be filtered
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The environment's node directory.
|
||||
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
|
||||
:param services_list: List of services registered for the environment.
|
||||
:type services_list: List[str]
|
||||
:return: Filtered version of 'acl'
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
if source_node_id != "ANY":
|
||||
source_node_address = nodes[str(source_node_id)].ip_address
|
||||
else:
|
||||
source_node_address = source_node_id
|
||||
|
||||
if dest_node_id != "ANY":
|
||||
dest_node_address = nodes[str(dest_node_id)].ip_address
|
||||
else:
|
||||
dest_node_address = dest_node_id
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
# TODO: This should throw an error because protocol is a string
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_allow_acl_rules(
|
||||
self,
|
||||
source_node_id: int,
|
||||
dest_node_id: str,
|
||||
protocol: int,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List ALLOW rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
:type source_node_id: int
|
||||
:param dest_node_id: Destination node
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: int
|
||||
:param port: Port
|
||||
:type port: str
|
||||
:param acl: Firewall ruleset which is applied to the network
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The simulation's node store
|
||||
:type nodes: Dict[str, NodeUnion]
|
||||
:param services_list: Services list
|
||||
:type services_list: List[str]
|
||||
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
|
||||
desination nodes
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "ALLOW":
|
||||
allowed_rules[rule_key] = rule_value
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def get_deny_acl_rules(
|
||||
self,
|
||||
source_node_id: int,
|
||||
dest_node_id: str,
|
||||
protocol: int,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List DENY rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
:type source_node_id: int
|
||||
:param dest_node_id: Destination node
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: int
|
||||
:param port: Port
|
||||
:type port: str
|
||||
:param acl: Firewall ruleset which is applied to the network
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The simulation's node store
|
||||
:type nodes: Dict[str, NodeUnion]
|
||||
:param services_list: Services list
|
||||
:type services_list: List[str]
|
||||
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
|
||||
desination nodes
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "DENY":
|
||||
allowed_rules[rule_key] = rule_value
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
||||
|
||||
- Which ACL rules already exist, - otherwise:
|
||||
- The agent would perminently get stuck in a loop of performing the same action over and over.
|
||||
(best action is to block something, but its already blocked but doesn't know this)
|
||||
- The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
|
||||
if it doesnt know what rules exist)
|
||||
- The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
|
||||
in the default config one of the green IERs is blocked by default, but it has no way of knowing this
|
||||
based on the observation space. Additionally, potentially in the future, once a node state
|
||||
has been fixed (no longer compromised), it needs a way to know it should reallow traffic.
|
||||
A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this.
|
||||
|
||||
There doesn't seem like there's much that can be done if an Operating or OS State is compromised
|
||||
|
||||
If a service node becomes compromised there's a decision to make - do we block that service?
|
||||
Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
|
||||
Cons: Will block a green IER, decreasing the reward
|
||||
We decide to block the service.
|
||||
|
||||
Potentially a better solution (for the reward) would be to block the incomming traffic from compromised
|
||||
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
|
||||
an overwhelmed state, so we don't do this.
|
||||
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
# obs = convert_to_old_obs(obs)
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, _, _, *s = r_obs
|
||||
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
# 1. Check if node is compromised. If so we want to block its outwards services
|
||||
# a. If it is comprimised check if there's an allow rule we should delete.
|
||||
# cons: might delete a multi-rule from any source node (ANY -> x)
|
||||
# b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate
|
||||
# c. OPTIONAL (no allow rule = blocked): Add a DENY rule
|
||||
found_action = False
|
||||
for service_num, service_states in enumerate(s):
|
||||
for x, service_state in enumerate(service_states):
|
||||
if service_state == "COMPROMISED":
|
||||
action_source_id = x + 1 # +1 as 0 is any
|
||||
action_destination_id = "ANY"
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
action_port = "ANY"
|
||||
|
||||
allow_rules = self.get_allow_acl_rules(
|
||||
action_source_id,
|
||||
action_destination_id,
|
||||
action_protocol,
|
||||
action_port,
|
||||
self._env.acl,
|
||||
self._env.nodes,
|
||||
self._env.services_list,
|
||||
)
|
||||
deny_rules = self.get_deny_acl_rules(
|
||||
action_source_id,
|
||||
action_destination_id,
|
||||
action_protocol,
|
||||
action_port,
|
||||
self._env.acl,
|
||||
self._env.nodes,
|
||||
self._env.services_list,
|
||||
)
|
||||
if len(allow_rules) > 0:
|
||||
# Check if there's an allow rule we should delete
|
||||
rule = list(allow_rules.values())[0]
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = rule.get_source_ip()
|
||||
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
|
||||
action_destination_ip = rule.get_dest_ip()
|
||||
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
|
||||
action_protocol_name = rule.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = rule.get_port()
|
||||
action_port = (
|
||||
self._env.ports_list.index(action_port_name) + 1
|
||||
) # convert port name e.g. '80' to index
|
||||
|
||||
found_action = True
|
||||
break
|
||||
elif len(deny_rules) > 0:
|
||||
# TODO OPTIONAL
|
||||
# If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
|
||||
# to create another
|
||||
# Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
|
||||
continue
|
||||
else:
|
||||
# TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
|
||||
action_decision = "CREATE"
|
||||
action_permission = "DENY"
|
||||
break
|
||||
if found_action:
|
||||
break
|
||||
|
||||
# 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and
|
||||
# add an Allow rule if the green IER is being blocked.
|
||||
# a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
|
||||
# If there's a DENY rule delete it if:
|
||||
# - There isn't already a deny rule
|
||||
# - It doesnt allows a comprimised node to become operational.
|
||||
# b. Add an ALLOW rule if:
|
||||
# - There isn't already an allow rule
|
||||
# - It doesnt allows a comprimised node to become operational
|
||||
|
||||
if not found_action:
|
||||
# Which Green IERS are blocked
|
||||
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
|
||||
for ier_key, ier in blocked_green_iers.items():
|
||||
# Which ALLOW rules are allowing this IER (none)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
|
||||
|
||||
# If there are no blocking rules, it may be being blocked by default
|
||||
# If there is already an allow rule
|
||||
node_id_to_check = int(ier.get_source_node_id())
|
||||
service_name_to_check = ier.get_protocol()
|
||||
service_id_to_check = self._env.services_list.index(service_name_to_check)
|
||||
|
||||
# Service state of the the source node in the ier
|
||||
service_state = s[service_id_to_check][node_id_to_check - 1]
|
||||
|
||||
if len(allowing_rules) == 0 and service_state != "COMPROMISED":
|
||||
action_decision = "CREATE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_id = int(ier.get_source_node_id())
|
||||
action_destination_id = int(ier.get_dest_node_id())
|
||||
action_protocol_name = ier.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = ier.get_port()
|
||||
action_port = (
|
||||
self._env.ports_list.index(action_port_name) + 1
|
||||
) # convert port name e.g. '80' to index
|
||||
|
||||
found_action = True
|
||||
break
|
||||
|
||||
if found_action:
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
action_source_id,
|
||||
action_destination_id,
|
||||
action_protocol,
|
||||
action_port,
|
||||
]
|
||||
action = transform_action_acl_enum(action)
|
||||
action = get_new_action(action, self._env.action_dict)
|
||||
else:
|
||||
# If no good/useful action has been found, just perform a nothing action
|
||||
action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
action = transform_action_acl_enum(action)
|
||||
action = get_new_action(action, self._env.action_dict)
|
||||
return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Uses ONLY information from the current observation with NO knowledge
|
||||
of previous actions taken and NO reward feedback.
|
||||
|
||||
We rely on randomness to select the precise action, as we want to
|
||||
block all traffic originating from a compromised node, without being
|
||||
able to tell:
|
||||
1. Which ACL rules already exist
|
||||
2. Which actions the agent has already tried.
|
||||
|
||||
There is a high probability that the correct rule will not be deleted
|
||||
before the state becomes overwhelmed.
|
||||
|
||||
Currently, a deny rule does not overwrite an allow rule. The allow
|
||||
rules must be deleted.
|
||||
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, o, _, *s = r_obs
|
||||
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
|
||||
for service_num, service_states in enumerate(s):
|
||||
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
|
||||
if len(comprimised_states) == 0:
|
||||
# No states are COMPROMISED, try the next service
|
||||
continue
|
||||
|
||||
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = compromised_node
|
||||
# Randomly select a destination ID to block
|
||||
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
|
||||
action_destination_ip = (
|
||||
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
|
||||
)
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
# Randomly select a port to block
|
||||
# Bad assumption that number of protocols equals number of ports
|
||||
# AND no rules exist with an ANY port
|
||||
action_port = np.random.choice(list(range(1, len(s) + 1)))
|
||||
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
action_source_ip,
|
||||
action_destination_ip,
|
||||
action_protocol,
|
||||
action_port,
|
||||
]
|
||||
action = transform_action_acl_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# If no good/useful action has been found, just perform a nothing action
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, action_dict)
|
||||
return nothing_action
|
||||
125
src/primaite/agents/hardcoded_node.py
Normal file
125
src/primaite/agents/hardcoded_node.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic Node agent."""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good node-based action for the blue agent to take.
|
||||
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, o, os, *s = r_obs
|
||||
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
# Check in order of most important states (order doesn't currently
|
||||
# matter, but it probably should)
|
||||
# First see if any OS states are compromised
|
||||
for x, os_state in enumerate(os):
|
||||
if os_state == "COMPROMISED":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OS"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = 0 # does nothing isn't relevant for os
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# Next, see if any Services are compromised
|
||||
# We fix the compromised state before overwhelemd state,
|
||||
# If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored
|
||||
for service_num, service in enumerate(s):
|
||||
for x, service_state in enumerate(service):
|
||||
if service_state == "COMPROMISED":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "SERVICE"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = service_num
|
||||
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# Next, See if any services are overwhelmed
|
||||
# perhaps this should be fixed automatically when the compromised PCs issues are also resolved
|
||||
# Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs
|
||||
|
||||
for service_num, service in enumerate(s):
|
||||
for x, service_state in enumerate(service):
|
||||
if service_state == "OVERWHELMED":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "SERVICE"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = service_num
|
||||
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# Finally, turn on any off nodes
|
||||
for x, operating_state in enumerate(o):
|
||||
if os_state == "OFF":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OPERATING"
|
||||
property_action = "ON" # Why reset it when we can just turn it on
|
||||
action_service_index = 0 # does nothing isn't relevant for operating state
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
|
||||
action = transform_action_node_enum(action, action_dict)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# If no good actions, just go with an action that wont do any harm
|
||||
action_node_id = 1
|
||||
action_node_property = "NONE"
|
||||
property_action = "ON"
|
||||
action_service_index = 0
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
|
||||
return action
|
||||
286
src/primaite/agents/rllib.py
Normal file
286
src/primaite/agents/rllib.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.exceptions import RLlibAgentError
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: verify type of env_config
|
||||
def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
session_path=env_config["session_path"],
|
||||
timestamp_str=env_config["timestamp_str"],
|
||||
)
|
||||
|
||||
|
||||
# TODO: verify type hint return type
|
||||
def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
return logger_creator
|
||||
|
||||
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB")
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
# TODO: implement RLlib agent loading
|
||||
if session_path is not None:
|
||||
msg = "RLlib agent loading has not been implemented yet"
|
||||
_LOGGER.critical(msg)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if self._training_config.session_type == SessionType.EVAL:
|
||||
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
|
||||
_LOGGER.critical(msg)
|
||||
raise RLlibAgentError(msg)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: Union[PPOConfig, A2CConfig]
|
||||
|
||||
self._current_result: dict
|
||||
self._setup()
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}, "
|
||||
f"deep_learning_framework="
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
self._train_agent = None # Required to capture the learning agent to close after eval
|
||||
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
with the following key/value pairs:
|
||||
|
||||
- end_datetime: The date & time the session ended in iso format.
|
||||
- total_episodes: The total number of training episodes completed.
|
||||
- total_time_steps: The total number of training time steps completed.
|
||||
"""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._agent_config = self._agent_config_class()
|
||||
|
||||
self._agent_config.environment(
|
||||
env="primaite",
|
||||
env_config=dict(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
),
|
||||
)
|
||||
self._agent_config.seed = self._training_config.seed
|
||||
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_train_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_train_steps
|
||||
episodes = self._training_config.num_train_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self.save()
|
||||
super().learn()
|
||||
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
|
||||
if self._training_config.session_type is not SessionType.TRAIN:
|
||||
self._train_agent = self._agent
|
||||
else:
|
||||
self._agent.stop()
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def _unpack_saved_agent_into_eval(self) -> Path:
|
||||
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
|
||||
agent_restore_path = self.evaluation_path / "agent_restore"
|
||||
if agent_restore_path.exists():
|
||||
shutil.rmtree(agent_restore_path)
|
||||
agent_restore_path.mkdir()
|
||||
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
|
||||
zip_file.extractall(agent_restore_path)
|
||||
return agent_restore_path
|
||||
|
||||
def _setup_eval(self):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
self._setup_eval()
|
||||
|
||||
self._env: Primaite = Primaite(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
|
||||
)
|
||||
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if self._training_config.deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
for step in range(time_steps):
|
||||
action = self._agent.compute_single_action(observation=obs, explore=False)
|
||||
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
# Now we're safe to close the learning agent and write the mean rewards per episode for it
|
||||
if self._training_config.session_type is not SessionType.TRAIN:
|
||||
self._train_agent.stop()
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
# Perform a clean-up of the unpacked agent
|
||||
if (self.evaluation_path / "agent_restore").exists():
|
||||
shutil.rmtree((self.evaluation_path / "agent_restore"))
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]) -> RLlibAgent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, overwrite_existing: bool = True) -> None:
|
||||
"""Save the agent."""
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
temp_dir.mkdir()
|
||||
|
||||
# Save the agent to the temp dir
|
||||
self._agent.save(str(temp_dir))
|
||||
|
||||
# Capture the saved Rllib checkpoint inside the temp directory
|
||||
for file in temp_dir.iterdir():
|
||||
checkpoint_dir = file
|
||||
break
|
||||
|
||||
# Zip the folder
|
||||
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
206
src/primaite/agents/sb3.py
Normal file
206
src/primaite/agents/sb3.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
super().__init__(
|
||||
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_class: Union[PPO, A2C]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}"
|
||||
)
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up the SB3 Agent."""
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
legacy_training_config=self.legacy_training_config,
|
||||
legacy_lay_down_config=self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
# check if there is a zip file that needs to be loaded
|
||||
load_file = next(self.session_path.rglob("*.zip"), None)
|
||||
|
||||
if not load_file:
|
||||
# create a new env and agent
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_train_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
else:
|
||||
# set env values from session metadata
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# load environment values
|
||||
if self.is_eval:
|
||||
# evaluation always starts at 0
|
||||
self._env.episode_count = 0
|
||||
self._env.total_step_count = 0
|
||||
else:
|
||||
# carry on from previous learning sessions
|
||||
self._env.episode_count = md_dict["learning"]["total_episodes"]
|
||||
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
|
||||
|
||||
# load the file
|
||||
self._agent = self._agent_class.load(load_file, env=self._env)
|
||||
|
||||
# set agent values
|
||||
self._agent.verbose = self.sb3_output_verbose_level
|
||||
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
|
||||
|
||||
super()._setup()
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_train_steps
|
||||
episodes = self._training_config.num_train_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self._env._write_av_reward_per_episode() # noqa
|
||||
self.save()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if self._training_config.deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
self._env._write_av_reward_per_episode() # noqa
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
59
src/primaite/agents/simple.py
Normal file
59
src/primaite/agents/simple.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A Random Agent.
|
||||
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return self._env.action_space.sample()
|
||||
|
||||
|
||||
class DummyAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A Dummy Agent.
|
||||
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A do nothing ACL agent.
|
||||
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
|
||||
return nothing_action
|
||||
|
||||
|
||||
class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A do nothing Node agent.
|
||||
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = [1, "NONE", "ON", 0]
|
||||
nothing_action = transform_action_node_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
# nothing_action should currently always be 0
|
||||
|
||||
return nothing_action
|
||||
450
src/primaite/agents/utils.py
Normal file
450
src/primaite/agents/utils.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
LinkStatus,
|
||||
NodeHardwareAction,
|
||||
NodePOLType,
|
||||
NodeSoftwareAction,
|
||||
SoftwareState,
|
||||
)
|
||||
|
||||
|
||||
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: The same action list, but with the encodings translated back into meaningful labels
|
||||
:rtype: List[Union[int,str]]
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: The same action list, but with the encodings translated back into meaningful labels
|
||||
:rtype: List[Union[int,str]]
|
||||
"""
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == 0:
|
||||
new_action[n + 2] = "ANY"
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def is_valid_node_action(action: List[int]) -> bool:
|
||||
"""
|
||||
Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
node_property = action_r[1]
|
||||
node_action = action_r[2]
|
||||
|
||||
# print("node property", node_property, "\nnode action", node_action)
|
||||
|
||||
if node_property == "NONE":
|
||||
return False
|
||||
if node_action == "NONE":
|
||||
return False
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in [
|
||||
"NONE",
|
||||
"PATCHING",
|
||||
]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action(action: List[int]) -> bool:
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect.
|
||||
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
action_decision = action_r[0]
|
||||
action_permission = action_r[1]
|
||||
action_source_id = action_r[2]
|
||||
action_destination_id = action_r[3]
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
# DENY is unnecessary, we can create and delete allow rules instead
|
||||
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action: List[int]) -> bool:
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
action_r = transform_action_acl_readable(action)
|
||||
action_protocol = action_r[4]
|
||||
action_port = action_r[5]
|
||||
|
||||
# Don't allow protocols or ports to be ANY
|
||||
# in the future we might want to do the opposite, and only have ANY option for ports and service
|
||||
if action_protocol == "ANY":
|
||||
return False
|
||||
if action_port == "ANY":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
|
||||
"""Transform list of transactions to readable list of each observation property.
|
||||
|
||||
example:
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||
|
||||
:param obs: Raw observation from the environment.
|
||||
:type obs: np.ndarray
|
||||
:return: The same observation, but the encoded integer values are replaced with readable names.
|
||||
:rtype: List[List[Union[str, int]]]
|
||||
"""
|
||||
ids = [i for i in obs[:, 0]]
|
||||
operating_states = [HardwareState(i).name for i in obs[:, 1]]
|
||||
os_states = [SoftwareState(i).name for i in obs[:, 2]]
|
||||
new_obs = [ids, operating_states, os_states]
|
||||
|
||||
for service in range(4, obs.shape[1]):
|
||||
# Links bit/s don't have a service state
|
||||
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
|
||||
new_obs.append(service_states)
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
|
||||
"""Transform observation to readable format.
|
||||
|
||||
example
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
|
||||
:param obs: Raw observation from the environment.
|
||||
:type obs: np.ndarray
|
||||
:return: The same observation, but the encoded integer values are replaced with readable names.
|
||||
:rtype: List[List[Union[str, int]]]
|
||||
"""
|
||||
changed_obs = transform_change_obs_readable(obs)
|
||||
new_obs = list(zip(*changed_obs))
|
||||
# Convert list of tuples to list of lists
|
||||
new_obs = [list(i) for i in new_obs]
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
|
||||
:param obs: observation in the 'old' (NodeLinkTable) format
|
||||
:type obs: np.ndarray
|
||||
:param num_nodes: number of nodes in the network, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:return: reformatted observation
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
|
||||
new_obs = obs[:num_nodes, 1:].flatten()
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
|
||||
"""Convert to old observation.
|
||||
|
||||
Links filled with 0's as no information is included in new observation space.
|
||||
|
||||
example:
|
||||
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
|
||||
|
||||
new_obs = array([[ 1, 1, 1, 1],
|
||||
[ 2, 1, 1, 1],
|
||||
[ 3, 1, 1, 1],
|
||||
...
|
||||
[20, 0, 0, 0]])
|
||||
|
||||
:param obs: observation in the 'new' (MultiDiscrete) format
|
||||
:type obs: np.ndarray
|
||||
:param num_nodes: number of nodes in the network, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:param num_links: number of links in the network, defaults to 10
|
||||
:type num_links: int, optional
|
||||
:param num_services: number of services on the network, defaults to 1
|
||||
:type num_services: int, optional
|
||||
:return: 2-d BOX observation space, in the same format as NodeLinkTable
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
# Convert back to more readable, original format
|
||||
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
|
||||
|
||||
# Add empty links back and add node ID back
|
||||
s = np.zeros(
|
||||
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
|
||||
dtype=np.int64,
|
||||
)
|
||||
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
|
||||
s[:num_nodes, 1:] = reshaped_nodes # put values back in
|
||||
new_obs = s
|
||||
|
||||
# Add links back in
|
||||
links = obs[-num_links:]
|
||||
# Links will be added to the last protocol/service slot but they are not specific to that service
|
||||
new_obs[num_nodes:, -1] = links
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(
|
||||
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
|
||||
) -> str:
|
||||
"""Build a string describing the difference between two observations.
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
|
||||
output = 'ID 1: SERVICE 2 set to GOOD'
|
||||
|
||||
:param obs1: First observation
|
||||
:type obs1: np.ndarray
|
||||
:param obs2: Second observation
|
||||
:type obs2: np.ndarray
|
||||
:param num_nodes: How many nodes are in the network laydown, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:param num_links: How many links are in the network laydown, defaults to 10
|
||||
:type num_links: int, optional
|
||||
:param num_services: How many services are configured for this scenario, defaults to 1
|
||||
:type num_services: int, optional
|
||||
:return: A multi-line string with a human-readable description of the difference.
|
||||
:rtype: str
|
||||
"""
|
||||
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
|
||||
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
|
||||
list_of_changes = []
|
||||
for n, row in enumerate(obs1 - obs2):
|
||||
if row.any() != 0:
|
||||
relevant_changes = np.where(row != 0, obs2[n], -1)
|
||||
relevant_changes[0] = obs2[n, 0] # ID is always relevant
|
||||
is_link = relevant_changes[0] > num_nodes
|
||||
desc = _describe_obs_change_helper(relevant_changes, is_link)
|
||||
list_of_changes.append(desc)
|
||||
|
||||
change_string = "\n ".join(list_of_changes)
|
||||
if len(list_of_changes) > 0:
|
||||
change_string = "\n " + change_string
|
||||
return change_string
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str:
|
||||
"""
|
||||
Helper funcion to describe what has changed.
|
||||
|
||||
example:
|
||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||
|
||||
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
|
||||
|
||||
:param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one
|
||||
row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new
|
||||
status where it has changed.
|
||||
:type obs_change: List[int]
|
||||
:param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node.
|
||||
:type is_link: bool
|
||||
:return: A human-readable description of the difference between the two observation rows.
|
||||
:rtype: str
|
||||
"""
|
||||
# Indexes where a change has occured, not including 0th index
|
||||
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
|
||||
# Node pol types, Indexes >= 3 are service nodes
|
||||
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
|
||||
# Account for hardware states, software sattes and links
|
||||
states = [
|
||||
LinkStatus(obs_change[i]).name
|
||||
if is_link
|
||||
else HardwareState(obs_change[i]).name
|
||||
if i == 1
|
||||
else SoftwareState(obs_change[i]).name
|
||||
for i in index_changed
|
||||
]
|
||||
|
||||
if not is_link:
|
||||
desc = f"ID {obs_change[0]}:"
|
||||
for node_pol_type, state in list(zip(NodePOLTypes, states)):
|
||||
desc = desc + " " + node_pol_type + " changed to " + state + "."
|
||||
else:
|
||||
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
|
||||
|
||||
return desc
|
||||
|
||||
|
||||
def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]:
|
||||
"""Convert a node action from readable string format, to enumerated format.
|
||||
|
||||
example:
|
||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||
:param action: Action in 'readable' format
|
||||
:type action: List[Union[str,int]]
|
||||
:return: Action with verbs encoded as ints
|
||||
:rtype: List[int]
|
||||
"""
|
||||
action_node_id = action[0]
|
||||
action_node_property = NodePOLType[action[1]].value
|
||||
|
||||
if action[1] == "OPERATING":
|
||||
property_action = NodeHardwareAction[action[2]].value
|
||||
elif action[1] == "OS" or action[1] == "SERVICE":
|
||||
property_action = NodeSoftwareAction[action[2]].value
|
||||
else:
|
||||
property_action = 0
|
||||
|
||||
action_service_index = action[3]
|
||||
|
||||
new_action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray:
|
||||
"""
|
||||
Convert acl action from readable str format, to enumerated format.
|
||||
|
||||
:param action: ACL-based action expressed as a list of human-readable ints and strings
|
||||
:type action: List[Union[int,str]]
|
||||
:return: The same action but encoded to contain only integers.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
|
||||
action_permissions = {"DENY": 0, "ALLOW": 1}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == "ANY":
|
||||
new_action[n + 2] = 0
|
||||
|
||||
new_action = np.array(new_action)
|
||||
return new_action
|
||||
|
||||
|
||||
def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str:
|
||||
"""Get the node ID of an IP address.
|
||||
|
||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
||||
|
||||
:param ip: The IP address of the node whose ID is required
|
||||
:type ip: str
|
||||
:param node_dict: The environment's node registry dictionary
|
||||
:type node_dict: Dict[str,NodeUnion]
|
||||
:return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip`
|
||||
:rtype: str
|
||||
"""
|
||||
for node_key, node_value in node_dict.items():
|
||||
node_ip = node_value.ip_address
|
||||
if node_ip == ip:
|
||||
return node_key
|
||||
|
||||
|
||||
def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int:
|
||||
"""
|
||||
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
|
||||
Old_action can be either node or acl action type
|
||||
|
||||
:param old_action: Action expressed as a list of choices, eg. [1,1,1,0]
|
||||
:type old_action: np.ndarray
|
||||
:param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions.
|
||||
:type action_dict: Dict[int,List]
|
||||
:return: Action key correspoinding to the input `old_action`
|
||||
:rtype: int
|
||||
"""
|
||||
for key, val in action_dict.items():
|
||||
if list(val) == list(old_action):
|
||||
return key
|
||||
# Not all possible actions are included in dict, only valid action are
|
||||
# if action is not in the dict, its an invalid action so return 0
|
||||
return 0
|
||||
213
src/primaite/cli.py
Normal file
213
src/primaite/cli.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.data_viz import PlotlyTemplate
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def build_dirs() -> None:
|
||||
"""Build the PrimAITE app directories."""
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
PRIMAITE_PATHS.mkdirs()
|
||||
|
||||
|
||||
@app.command()
|
||||
def reset_notebooks(overwrite: bool = True) -> None:
|
||||
"""
|
||||
Force a reset of the demo notebooks in the users notebooks directory.
|
||||
|
||||
:param overwrite: If True, will overwrite existing demo notebooks.
|
||||
"""
|
||||
from primaite.setup import reset_demo_notebooks
|
||||
|
||||
reset_demo_notebooks.run(overwrite)
|
||||
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]) -> None:
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
import re
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
if os.path.isfile(PRIMAITE_PATHS.app_log_file_path):
|
||||
with open(PRIMAITE_PATHS.app_log_file_path) as file:
|
||||
lines = file.readlines()
|
||||
for line in lines[-last_n:]:
|
||||
print(re.sub(r"\n*", "", line))
|
||||
|
||||
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
|
||||
|
||||
@app.command()
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the PrimAITE Log Level.
|
||||
|
||||
To View, simply call: primaite log-level
|
||||
|
||||
To set, call: primaite log-level <desired log level>
|
||||
|
||||
For example, to set the to debug, call: primaite log-level DEBUG
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if level:
|
||||
primaite_config["logging"]["log_level"] = level.value
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
else:
|
||||
level = primaite_config["logging"]["log_level"]
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
start_jupyter_session()
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
import primaite
|
||||
|
||||
print(primaite.__version__)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
old_installation_clean_up.run()
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from primaite import getLogger
|
||||
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_LOGGER.info("Performing the PrimAITE first-time setup...")
|
||||
|
||||
_LOGGER.info("Building primaite_config.yaml...")
|
||||
|
||||
_LOGGER.info("Building the PrimAITE app directories...")
|
||||
PRIMAITE_PATHS.mkdirs()
|
||||
|
||||
_LOGGER.info("Rebuilding the demo notebooks...")
|
||||
reset_demo_notebooks.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
|
||||
old_installation_clean_up.run()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
tc: Optional[str] = None,
|
||||
ldc: Optional[str] = None,
|
||||
load: Optional[str] = None,
|
||||
legacy_tc: bool = False,
|
||||
legacy_ldc: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
|
||||
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
|
||||
load: The directory of a previous session. Optional. If no value is passed, then the session
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
|
||||
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
|
||||
|
||||
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if load is not None:
|
||||
# run a loaded session
|
||||
run(session_path=load)
|
||||
|
||||
else:
|
||||
# start a new session using tc and ldc
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(
|
||||
training_config_path=tc,
|
||||
lay_down_config_path=ldc,
|
||||
legacy_training_config=legacy_tc,
|
||||
legacy_lay_down_config=legacy_ldc,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
To View, simply call: primaite plotly-template
|
||||
|
||||
To set, call: primaite plotly-template <desired template>
|
||||
|
||||
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if template:
|
||||
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE plotly template: {template.value}")
|
||||
else:
|
||||
template = primaite_config["session"]["outputs"]["plots"]["template"]
|
||||
print(f"PrimAITE plotly template: {template}")
|
||||
2
src/primaite/common/__init__.py
Normal file
2
src/primaite/common/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Objects which are shared between many PrimAITE modules."""
|
||||
8
src/primaite/common/custom_typing.py
Normal file
8
src/primaite/common/custom_typing.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Union
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
208
src/primaite/common/enums.py
Normal file
208
src/primaite/common/enums.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""Node type enumeration."""
|
||||
|
||||
CCTV = 1
|
||||
SWITCH = 2
|
||||
COMPUTER = 3
|
||||
LINK = 4
|
||||
MONITOR = 5
|
||||
PRINTER = 6
|
||||
LOP = 7
|
||||
RTU = 8
|
||||
ACTUATOR = 9
|
||||
SERVER = 10
|
||||
|
||||
|
||||
class Priority(Enum):
|
||||
"""Node priority enumeration."""
|
||||
|
||||
P1 = 1
|
||||
P2 = 2
|
||||
P3 = 3
|
||||
P4 = 4
|
||||
P5 = 5
|
||||
|
||||
|
||||
class HardwareState(Enum):
|
||||
"""Node hardware state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESETTING = 3
|
||||
SHUTTING_DOWN = 4
|
||||
BOOTING = 5
|
||||
|
||||
|
||||
class SoftwareState(Enum):
|
||||
"""Software or Service state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
GOOD = 1
|
||||
PATCHING = 2
|
||||
COMPROMISED = 3
|
||||
OVERWHELMED = 4
|
||||
|
||||
|
||||
class NodePOLType(Enum):
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
NONE = 0
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
FILE = 4
|
||||
|
||||
|
||||
class NodePOLInitiator(Enum):
|
||||
"""Node Pattern of Life initiator enumeration."""
|
||||
|
||||
DIRECT = 1
|
||||
IER = 2
|
||||
SERVICE = 3
|
||||
|
||||
|
||||
class Protocol(Enum):
|
||||
"""Service protocol enumeration."""
|
||||
|
||||
LDAP = 0
|
||||
FTP = 1
|
||||
HTTPS = 2
|
||||
SMTP = 3
|
||||
RTP = 4
|
||||
IPP = 5
|
||||
TCP = 6
|
||||
NONE = 7
|
||||
|
||||
|
||||
class SessionType(Enum):
|
||||
"""The type of PrimAITE Session to be run."""
|
||||
|
||||
TRAIN = 1
|
||||
"Train an agent"
|
||||
EVAL = 2
|
||||
"Evaluate an agent"
|
||||
TRAIN_EVAL = 3
|
||||
"Train then evaluate an agent"
|
||||
|
||||
|
||||
class AgentFramework(Enum):
|
||||
"""The agent algorithm framework/package."""
|
||||
|
||||
CUSTOM = 0
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
"Stable Baselines3"
|
||||
RLLIB = 2
|
||||
"Ray RLlib"
|
||||
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
"""The deep learning framework."""
|
||||
|
||||
TF = "tf"
|
||||
"Tensorflow"
|
||||
TF2 = "tf2"
|
||||
"Tensorflow 2.x"
|
||||
TORCH = "torch"
|
||||
"PyTorch"
|
||||
|
||||
|
||||
class AgentIdentifier(Enum):
|
||||
"""The Red Agent algo/class."""
|
||||
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
"Proximal Policy Optimization"
|
||||
HARDCODED = 3
|
||||
"The Hardcoded agents"
|
||||
DO_NOTHING = 4
|
||||
"The DoNothing agents"
|
||||
RANDOM = 5
|
||||
"The RandomAgent"
|
||||
DUMMY = 6
|
||||
"The DummyAgent"
|
||||
|
||||
|
||||
class HardCodedAgentView(Enum):
|
||||
"""The view the deterministic hard-coded agent has of the environment."""
|
||||
|
||||
BASIC = 1
|
||||
"The current observation space only"
|
||||
FULL = 2
|
||||
"Full environment view with actions taken and reward feedback"
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
|
||||
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
BOX = 0
|
||||
MULTIDISCRETE = 1
|
||||
|
||||
|
||||
class FileSystemState(Enum):
|
||||
"""File System State."""
|
||||
|
||||
GOOD = 1
|
||||
CORRUPT = 2
|
||||
DESTROYED = 3
|
||||
REPAIRING = 4
|
||||
RESTORING = 5
|
||||
|
||||
|
||||
class NodeHardwareAction(Enum):
|
||||
"""Node hardware action."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESET = 3
|
||||
|
||||
|
||||
class NodeSoftwareAction(Enum):
|
||||
"""Node software action."""
|
||||
|
||||
NONE = 0
|
||||
PATCHING = 1
|
||||
|
||||
|
||||
class LinkStatus(Enum):
|
||||
"""Link traffic status."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
OVERLOAD = 4
|
||||
|
||||
|
||||
class SB3OutputVerboseLevel(IntEnum):
|
||||
"""The Stable Baselines3 learn/eval output verbosity level."""
|
||||
|
||||
NONE = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Any firewall rule type."""
|
||||
|
||||
NONE = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
47
src/primaite/common/protocol.py
Normal file
47
src/primaite/common/protocol.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name: str) -> None:
|
||||
"""
|
||||
Initialise a protocol.
|
||||
|
||||
:param _name: The name of the protocol
|
||||
:type _name: str
|
||||
"""
|
||||
self.name: str = _name
|
||||
self.load: int = 0 # bps
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Gets the protocol name.
|
||||
|
||||
Returns:
|
||||
The protocol name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets the protocol load.
|
||||
|
||||
Returns:
|
||||
The protocol load (bps)
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load: int) -> None:
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
|
||||
Args:
|
||||
_load: The load to add
|
||||
"""
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self) -> None:
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
28
src/primaite/common/service.py
Normal file
28
src/primaite/common/service.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SoftwareState
|
||||
|
||||
|
||||
class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Initialise a service.
|
||||
|
||||
:param name: The service name.
|
||||
:param port: The service port.
|
||||
:param software_state: The service SoftwareState.
|
||||
"""
|
||||
self.name: str = name
|
||||
self.port: str = port
|
||||
self.software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
|
||||
def reduce_patching_count(self) -> None:
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self.software_state = SoftwareState.GOOD
|
||||
2
src/primaite/config/__init__.py
Normal file
2
src/primaite/config/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Configuration parameters for running experiments."""
|
||||
@@ -0,0 +1,166 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.5
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: SWITCH2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.6
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH3
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.7
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: LINK
|
||||
id: '7'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '4'
|
||||
- item_type: LINK
|
||||
id: '8'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '2'
|
||||
- item_type: LINK
|
||||
id: '9'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '10'
|
||||
name: link4
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: link5
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: link6
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '3'
|
||||
- item_type: GREEN_IER
|
||||
id: '13'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '3'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 50
|
||||
end_step: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '15'
|
||||
start_step: 60
|
||||
end_step: 100
|
||||
load: 1000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '2'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '16'
|
||||
start_step: 80
|
||||
end_step: 80
|
||||
targetNodeId: '2'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: ACL_RULE
|
||||
id: '17'
|
||||
permission: ALLOW
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
@@ -0,0 +1,366 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: PC3
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.13
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: PC4
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.20.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: IDS
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '7'
|
||||
name: SWITCH2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '8'
|
||||
name: LOP1
|
||||
node_class: SERVICE
|
||||
node_type: LOP
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '9'
|
||||
name: SERVER1
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '10'
|
||||
name: SERVER2
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.20.15
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '13'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '14'
|
||||
name: link4
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '15'
|
||||
name: link5
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '16'
|
||||
name: link6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '8'
|
||||
- item_type: LINK
|
||||
id: '17'
|
||||
name: link7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '18'
|
||||
name: link8
|
||||
bandwidth: 1000000000
|
||||
source: '8'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '19'
|
||||
name: link9
|
||||
bandwidth: 1000000000
|
||||
source: '7'
|
||||
destination: '9'
|
||||
- item_type: LINK
|
||||
id: '20'
|
||||
name: link10
|
||||
bandwidth: 1000000000
|
||||
source: '7'
|
||||
destination: '10'
|
||||
- item_type: GREEN_IER
|
||||
id: '21'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '22'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '23'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '3'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '24'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '4'
|
||||
destination: '10'
|
||||
mission_criticality: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '25'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '26'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '27'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.13
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '28'
|
||||
permission: ALLOW
|
||||
source: 192.168.20.14
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '29'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.14
|
||||
destination: 192.168.10.13
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '30'
|
||||
permission: DENY
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '31'
|
||||
permission: DENY
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '32'
|
||||
permission: DENY
|
||||
source: 192.168.10.13
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: DENY
|
||||
source: 192.168.20.14
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 8
|
||||
- item_type: RED_POL
|
||||
id: '34'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_POL
|
||||
id: '35'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '2'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '36'
|
||||
start_step: 30
|
||||
end_step: 128
|
||||
load: 440000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_IER
|
||||
id: '37'
|
||||
start_step: 30
|
||||
end_step: 128
|
||||
load: 440000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '38'
|
||||
start_step: 30
|
||||
end_step: 30
|
||||
targetNodeId: '9'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
@@ -0,0 +1,164 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SERVER1
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '5'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '6'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '7'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- item_type: GREEN_IER
|
||||
id: '8'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '4'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '9'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '4'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '10'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '4'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '11'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.2
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '12'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.3
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '13'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.4
|
||||
destination: 192.168.1.3
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '15'
|
||||
start_step: 30
|
||||
end_step: 256
|
||||
load: 10000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '4'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '16'
|
||||
start_step: 40
|
||||
end_step: 40
|
||||
targetNodeId: '4'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
@@ -0,0 +1,546 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: CLIENT_1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: CLIENT_2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH_1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SECURITY_SUITE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH_2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '7'
|
||||
name: WEB_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '8'
|
||||
name: DATABASE_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '9'
|
||||
name: BACKUP_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.16
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: LINK_2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: LINK_3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- item_type: LINK
|
||||
id: '13'
|
||||
name: LINK_4
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '14'
|
||||
name: LINK_5
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '15'
|
||||
name: LINK_6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '16'
|
||||
name: LINK_7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '17'
|
||||
name: LINK_8
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '8'
|
||||
- item_type: LINK
|
||||
id: '18'
|
||||
name: LINK_9
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '9'
|
||||
- item_type: GREEN_IER
|
||||
id: '19'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '20'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '1'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '21'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '22'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '23'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '7'
|
||||
destination: '8'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '24'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 100000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '8'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '25'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '26'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '27'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '7'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '28'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '29'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '8'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '30'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '8'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '31'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '9'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '32'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 8
|
||||
- item_type: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 9
|
||||
- item_type: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 10
|
||||
- item_type: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 11
|
||||
- item_type: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 12
|
||||
- item_type: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 13
|
||||
- item_type: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 14
|
||||
- item_type: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.16
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 15
|
||||
- item_type: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 16
|
||||
- item_type: RED_POL
|
||||
id: '50'
|
||||
start_step: 50
|
||||
end_step: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '51'
|
||||
start_step: 75
|
||||
end_step: 105
|
||||
load: 10000
|
||||
protocol: UDP
|
||||
port: '53'
|
||||
source: '1'
|
||||
destination: '8'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '52'
|
||||
start_step: 100
|
||||
end_step: 100
|
||||
targetNodeId: '8'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_POL
|
||||
id: '53'
|
||||
start_step: 105
|
||||
end_step: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: FILE
|
||||
protocol: NA
|
||||
state: CORRUPT
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- item_type: RED_POL
|
||||
id: '54'
|
||||
start_step: 105
|
||||
end_step: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP_SQL
|
||||
state: COMPROMISED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- item_type: RED_POL
|
||||
id: '55'
|
||||
start_step: 125
|
||||
end_step: 125
|
||||
targetNodeId: '7'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
@@ -0,0 +1,168 @@
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 1
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
141
src/primaite/config/lay_down_config.py
Normal file
141
src/primaite/config/lay_down_config.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, List, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert a legacy lay down config to the new format.
|
||||
|
||||
:param legacy_config: A legacy lay down config.
|
||||
"""
|
||||
field_conversion_map = {
|
||||
"itemType": "item_type",
|
||||
"portsList": "ports_list",
|
||||
"serviceList": "service_list",
|
||||
"baseType": "node_class",
|
||||
"nodeType": "node_type",
|
||||
"hardwareState": "hardware_state",
|
||||
"softwareState": "software_state",
|
||||
"startStep": "start_step",
|
||||
"endStep": "end_step",
|
||||
"fileSystemState": "file_system_state",
|
||||
"ipAddress": "ip_address",
|
||||
"missionCriticality": "mission_criticality",
|
||||
}
|
||||
new_config = []
|
||||
for item in legacy_config:
|
||||
if "itemType" in item:
|
||||
if item["itemType"] in ["ACTIONS", "STEPS"]:
|
||||
continue
|
||||
new_dict = {}
|
||||
for key in item.keys():
|
||||
conversion_key = field_conversion_map.get(key)
|
||||
if key == "id" and "itemType" in item:
|
||||
if item["itemType"] == "NODE":
|
||||
conversion_key = "node_id"
|
||||
if conversion_key:
|
||||
new_dict[conversion_key] = item[key]
|
||||
else:
|
||||
new_dict[key] = item[key]
|
||||
new_config.append(new_dict)
|
||||
return new_config
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
"""
|
||||
Read in a lay down config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise False.
|
||||
:return: The lay down config as a dict.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading lay down config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_lay_down_config(config)
|
||||
except KeyError:
|
||||
msg = (
|
||||
f"Failed to convert lay down config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
return config
|
||||
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def ddos_basic_two_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def dos_very_basic_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
438
src/primaite/config/training_config.py
Normal file
438
src/primaite/config/training_config.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
RulePermissionType,
|
||||
SB3OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training"
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
"""
|
||||
The path to the example training_config_main.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
|
||||
agent_framework: AgentFramework = AgentFramework.SB3
|
||||
"The AgentFramework"
|
||||
|
||||
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
|
||||
"The DeepLearningFramework"
|
||||
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
|
||||
"The AgentIdentifier"
|
||||
|
||||
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
|
||||
"The view the deterministic hard-coded agent has of the environment"
|
||||
|
||||
random_red_agent: bool = False
|
||||
"Creates Random Red Agent Attacks"
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use"
|
||||
|
||||
num_train_episodes: int = 10
|
||||
"The number of episodes to train over during an training session"
|
||||
|
||||
num_train_steps: int = 256
|
||||
"The number of steps in an episode during an training session"
|
||||
|
||||
num_eval_episodes: int = 1
|
||||
"The number of episodes to train over during an evaluation session"
|
||||
|
||||
num_eval_steps: int = 256
|
||||
"The number of steps in an episode during an evaluation session"
|
||||
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes"
|
||||
|
||||
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
|
||||
"The observation space config dict"
|
||||
|
||||
time_delay: int = 10
|
||||
"The delay between steps (ms). Applies to generic agents only"
|
||||
|
||||
# file
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: bool = False
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
"File path and file name of agent if you're loading one in"
|
||||
|
||||
# Environment
|
||||
observation_space_high_value: int = 1000000000
|
||||
"The high value for the observation space"
|
||||
|
||||
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
|
||||
"Stable Baselines3 learn/eval output verbosity level"
|
||||
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: float = 0
|
||||
|
||||
# Node Hardware State
|
||||
off_should_be_on: float = -0.001
|
||||
off_should_be_resetting: float = -0.0005
|
||||
on_should_be_off: float = -0.0002
|
||||
on_should_be_resetting: float = -0.0005
|
||||
resetting_should_be_on: float = -0.0005
|
||||
resetting_should_be_off: float = -0.0002
|
||||
resetting: float = -0.0003
|
||||
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: float = 0.0002
|
||||
good_should_be_compromised: float = 0.0005
|
||||
good_should_be_overwhelmed: float = 0.0005
|
||||
patching_should_be_good: float = -0.0005
|
||||
patching_should_be_compromised: float = 0.0002
|
||||
patching_should_be_overwhelmed: float = 0.0002
|
||||
patching: float = -0.0003
|
||||
compromised_should_be_good: float = -0.002
|
||||
compromised_should_be_patching: float = -0.002
|
||||
compromised_should_be_overwhelmed: float = -0.002
|
||||
compromised: float = -0.002
|
||||
overwhelmed_should_be_good: float = -0.002
|
||||
overwhelmed_should_be_patching: float = -0.002
|
||||
overwhelmed_should_be_compromised: float = -0.002
|
||||
overwhelmed: float = -0.002
|
||||
|
||||
# Node File System State
|
||||
good_should_be_repairing: float = 0.0002
|
||||
good_should_be_restoring: float = 0.0002
|
||||
good_should_be_corrupt: float = 0.0005
|
||||
good_should_be_destroyed: float = 0.001
|
||||
repairing_should_be_good: float = -0.0005
|
||||
repairing_should_be_restoring: float = 0.0002
|
||||
repairing_should_be_corrupt: float = 0.0002
|
||||
repairing_should_be_destroyed: float = 0.0000
|
||||
repairing: float = -0.0003
|
||||
restoring_should_be_good: float = -0.001
|
||||
restoring_should_be_repairing: float = -0.0002
|
||||
restoring_should_be_corrupt: float = 0.0001
|
||||
restoring_should_be_destroyed: float = 0.0002
|
||||
restoring: float = -0.0006
|
||||
corrupt_should_be_good: float = -0.001
|
||||
corrupt_should_be_repairing: float = -0.001
|
||||
corrupt_should_be_restoring: float = -0.001
|
||||
corrupt_should_be_destroyed: float = 0.0002
|
||||
corrupt: float = -0.001
|
||||
destroyed_should_be_good: float = -0.002
|
||||
destroyed_should_be_repairing: float = -0.002
|
||||
destroyed_should_be_restoring: float = -0.002
|
||||
destroyed_should_be_corrupt: float = -0.002
|
||||
destroyed: float = -0.002
|
||||
scanning: float = -0.0002
|
||||
|
||||
# IER status
|
||||
red_ier_running: float = -0.0005
|
||||
green_ier_blocked: float = -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: int = 5
|
||||
"The time taken to patch the OS"
|
||||
|
||||
node_reset_duration: int = 5
|
||||
"The time taken to reset a node (hardware)"
|
||||
|
||||
node_booting_duration: int = 3
|
||||
"The Time taken to turn on the node"
|
||||
|
||||
node_shutdown_duration: int = 2
|
||||
"The time taken to turn off the node"
|
||||
|
||||
service_patching_duration: int = 5
|
||||
"The time taken to patch a service"
|
||||
|
||||
file_system_repairing_limit: int = 5
|
||||
"The time take to repair the file system"
|
||||
|
||||
file_system_restoring_limit: int = 5
|
||||
"The time take to restore the file system"
|
||||
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system"
|
||||
|
||||
deterministic: bool = False
|
||||
"If true, the training will be deterministic"
|
||||
|
||||
seed: Optional[int] = None
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
:param config_dict: The training config dict.
|
||||
:return: The instance of TrainingConfig.
|
||||
"""
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
"agent_identifier": AgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
|
||||
:param json_serializable: If True, Enums are converted to their
|
||||
string name.
|
||||
:return: The ``TrainingConfig`` as a dict.
|
||||
"""
|
||||
data = self.__dict__
|
||||
if json_serializable:
|
||||
data["agent_framework"] = self.agent_framework.name
|
||||
data["deep_learning_framework"] = self.deep_learning_framework.name
|
||||
data["agent_identifier"] = self.agent_identifier.name
|
||||
data["action_type"] = self.action_type.name
|
||||
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
data["implicit_acl_rule"] = self.implicit_acl_rule.name
|
||||
|
||||
return data
|
||||
|
||||
def __str__(self) -> str:
|
||||
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
|
||||
tc = f"{self.agent_framework}, "
|
||||
if self.agent_framework is AgentFramework.RLLIB:
|
||||
tc += f"{self.deep_learning_framework}, "
|
||||
tc += f"{self.agent_identifier}, "
|
||||
if self.agent_identifier is AgentIdentifier.HARDCODED:
|
||||
tc += f"{self.hard_coded_agent_view}, "
|
||||
tc += f"{self.action_type}, "
|
||||
tc += f"observation_space={obs_str}, "
|
||||
if self.session_type is SessionType.TRAIN:
|
||||
tc += f"{self.num_train_episodes} episodes @ "
|
||||
tc += f"{self.num_train_steps} steps"
|
||||
elif self.session_type is SessionType.EVAL:
|
||||
tc += f"{self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
else:
|
||||
tc += f"Training: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
return tc
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise
|
||||
False.
|
||||
:return: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
:raises TypeError: When the TrainingConfig object cannot be created
|
||||
using the values from the config file read from ``file_path``.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading training config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_training_config_dict(config)
|
||||
|
||||
except KeyError as e:
|
||||
msg = (
|
||||
f"Failed to convert training config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise e
|
||||
try:
|
||||
return TrainingConfig.from_dict(config)
|
||||
except TypeError as e:
|
||||
msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
|
||||
_LOGGER.critical(msg, exc_info=True)
|
||||
raise e
|
||||
msg = f"Cannot load the training config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_train_steps: int = 256,
|
||||
num_eval_steps: int = 256,
|
||||
num_train_episodes: int = 10,
|
||||
num_eval_episodes: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param agent_framework: The agent framework to use as legacy training
|
||||
configs don't have agent_framework values.
|
||||
:param agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have agent_identifier values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:param num_train_steps: The number of train steps to set as legacy training configs
|
||||
don't have num_train_steps values.
|
||||
:param num_eval_steps: The number of eval steps to set as legacy training configs
|
||||
don't have num_eval_steps values.
|
||||
:param num_train_episodes: The number of train episodes to set as legacy training configs
|
||||
don't have num_train_episodes values.
|
||||
:param num_eval_episodes: The number of eval episodes to set as legacy training configs
|
||||
don't have num_eval_episodes values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"agent_framework": agent_framework.name,
|
||||
"agent_identifier": agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_train_steps": num_train_steps,
|
||||
"num_eval_steps": num_eval_steps,
|
||||
"num_train_episodes": num_train_episodes,
|
||||
"num_eval_episodes": num_eval_episodes,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
|
||||
}
|
||||
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
|
||||
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
config_dict[new_key] = value
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
:param legacy_key: A legacy training config key.
|
||||
:return: The mapped key.
|
||||
"""
|
||||
key_mapping = {
|
||||
"agentIdentifier": None,
|
||||
"numEpisodes": "num_train_episodes",
|
||||
"numSteps": "num_train_steps",
|
||||
"timeDelay": "time_delay",
|
||||
"configFilename": None,
|
||||
"sessionType": "session_type",
|
||||
"loadAgent": "load_agent",
|
||||
"agentLoadFile": "agent_load_file",
|
||||
"observationSpaceHighValue": "observation_space_high_value",
|
||||
"allOk": "all_ok",
|
||||
"offShouldBeOn": "off_should_be_on",
|
||||
"offShouldBeResetting": "off_should_be_resetting",
|
||||
"onShouldBeOff": "on_should_be_off",
|
||||
"onShouldBeResetting": "on_should_be_resetting",
|
||||
"resettingShouldBeOn": "resetting_should_be_on",
|
||||
"resettingShouldBeOff": "resetting_should_be_off",
|
||||
"resetting": "resetting",
|
||||
"goodShouldBePatching": "good_should_be_patching",
|
||||
"goodShouldBeCompromised": "good_should_be_compromised",
|
||||
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
|
||||
"patchingShouldBeGood": "patching_should_be_good",
|
||||
"patchingShouldBeCompromised": "patching_should_be_compromised",
|
||||
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
|
||||
"patching": "patching",
|
||||
"compromisedShouldBeGood": "compromised_should_be_good",
|
||||
"compromisedShouldBePatching": "compromised_should_be_patching",
|
||||
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
|
||||
"compromised": "compromised",
|
||||
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
|
||||
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
|
||||
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
|
||||
"overwhelmed": "overwhelmed",
|
||||
"goodShouldBeRepairing": "good_should_be_repairing",
|
||||
"goodShouldBeRestoring": "good_should_be_restoring",
|
||||
"goodShouldBeCorrupt": "good_should_be_corrupt",
|
||||
"goodShouldBeDestroyed": "good_should_be_destroyed",
|
||||
"repairingShouldBeGood": "repairing_should_be_good",
|
||||
"repairingShouldBeRestoring": "repairing_should_be_restoring",
|
||||
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
|
||||
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
|
||||
"repairing": "repairing",
|
||||
"restoringShouldBeGood": "restoring_should_be_good",
|
||||
"restoringShouldBeRepairing": "restoring_should_be_repairing",
|
||||
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
|
||||
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
|
||||
"restoring": "restoring",
|
||||
"corruptShouldBeGood": "corrupt_should_be_good",
|
||||
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
|
||||
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
|
||||
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
|
||||
"corrupt": "corrupt",
|
||||
"destroyedShouldBeGood": "destroyed_should_be_good",
|
||||
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
|
||||
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
|
||||
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
|
||||
"destroyed": "destroyed",
|
||||
"scanning": "scanning",
|
||||
"redIerRunning": "red_ier_running",
|
||||
"greenIerBlocked": "green_ier_blocked",
|
||||
"osPatchingDuration": "os_patching_duration",
|
||||
"nodeResetDuration": "node_reset_duration",
|
||||
"nodeBootingDuration": "node_booting_duration",
|
||||
"nodeShutdownDuration": "node_shutdown_duration",
|
||||
"servicePatchingDuration": "service_patching_duration",
|
||||
"fileSystemRepairingLimit": "file_system_repairing_limit",
|
||||
"fileSystemRestoringLimit": "file_system_restoring_limit",
|
||||
"fileSystemScanningLimit": "file_system_scanning_limit",
|
||||
}
|
||||
return key_mapping[legacy_key]
|
||||
15
src/primaite/data_viz/__init__.py
Normal file
15
src/primaite/data_viz/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Utility to generate plots of sessions metrics after PrimAITE."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PlotlyTemplate(Enum):
|
||||
"""The built-in plotly templates."""
|
||||
|
||||
PLOTLY = "plotly"
|
||||
PLOTLY_WHITE = "plotly_white"
|
||||
PLOTLY_DARK = "plotly_dark"
|
||||
GGPLOT2 = "ggplot2"
|
||||
SEABORN = "seaborn"
|
||||
SIMPLE_WHITE = "simple_white"
|
||||
NONE = "none"
|
||||
73
src/primaite/data_viz/session_plots.py
Normal file
73
src/primaite/data_viz/session_plots.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
|
||||
def get_plotly_config() -> Dict:
|
||||
"""Get the plotly config from primaite_config.yaml."""
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
return primaite_config["session"]["outputs"]["plots"]
|
||||
|
||||
|
||||
def plot_av_reward_per_episode(
|
||||
av_reward_per_episode_csv: Union[str, Path],
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""
|
||||
Plot the average reward per episode from a csv session output.
|
||||
|
||||
:param av_reward_per_episode_csv: The average reward per episode csv
|
||||
file path.
|
||||
:param title: The plot title. This is optional.
|
||||
:param subtitle: The plot subtitle. This is optional.
|
||||
:return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
|
||||
"""
|
||||
df = pl.read_csv(av_reward_per_episode_csv)
|
||||
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["Episode"],
|
||||
y=df["Average Reward"],
|
||||
mode="lines",
|
||||
name="Mean Reward per Episode",
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
"rangeslider": {"visible": config["range_slider"]},
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
2
src/primaite/environment/__init__.py
Normal file
2
src/primaite/environment/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""
|
||||
735
src/primaite/environment/observations.py
Normal file
735
src/primaite/environment/observations.py
Normal file
@@ -0,0 +1,735 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
# This dependency is only needed for type hints,
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise observation component.
|
||||
|
||||
:param env: Primaite training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
_LOGGER.info(f"Initialising {self} observation component")
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
self.structure: List[str]
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""
|
||||
Table with nodes and links as rows and hardware/software status as cols.
|
||||
|
||||
This will create the observation space formatted as a table of integers.
|
||||
There is one row per node, followed by one row per link.
|
||||
The number of columns is 4 plus one per service. They are:
|
||||
|
||||
* node/link ID
|
||||
* node hardware status / 0 for links
|
||||
* node operating system status (if active/service) / 0 for links
|
||||
* node file system status (active/service only) / 0 for links
|
||||
* node service1 status / traffic load from that service for links
|
||||
* node service2 status / traffic load from that service for links
|
||||
* ...
|
||||
* node serviceN status / traffic load from that service for links
|
||||
|
||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||
``(12, 7)``
|
||||
"""
|
||||
|
||||
_FIXED_PARAMETERS: int = 4
|
||||
_MAX_VAL: int = 1_000_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeLinkTable observation space component.
|
||||
|
||||
:param env: Training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
num_items = self.env.num_links + self.env.num_nodes
|
||||
num_columns = self.env.num_services + self._FIXED_PARAMETERS
|
||||
observation_shape = (num_items, num_columns)
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.Box(
|
||||
low=0,
|
||||
high=self._MAX_VAL,
|
||||
shape=observation_shape,
|
||||
dtype=self._DATA_TYPE,
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeLinkTable`
|
||||
"""
|
||||
item_index = 0
|
||||
nodes = self.env.nodes
|
||||
links = self.env.links
|
||||
# Do nodes first
|
||||
for _, node in nodes.items():
|
||||
self.current_observation[item_index][0] = int(node.node_id)
|
||||
self.current_observation[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.current_observation[item_index][2] = node.software_state.value
|
||||
self.current_observation[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.env.services_list:
|
||||
if node.has_service(service):
|
||||
self.current_observation[item_index][service_index] = node.get_service_state(service).value
|
||||
else:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
else:
|
||||
# Not a service node
|
||||
for service in self.env.services_list:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
item_index += 1
|
||||
|
||||
# Now do links
|
||||
for _, link in links.items():
|
||||
self.current_observation[item_index][0] = int(link.get_id())
|
||||
self.current_observation[item_index][1] = 0
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
nodes = self.env.nodes.values()
|
||||
links = self.env.links.values()
|
||||
|
||||
structure = []
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
node_id = node.node_id
|
||||
node_labels = [
|
||||
f"node_{node_id}_id",
|
||||
f"node_{node_id}_hardware_status",
|
||||
f"node_{node_id}_os_status",
|
||||
f"node_{node_id}_fs_status",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
node_labels.append(f"node_{node_id}_service_{serv}_status")
|
||||
|
||||
structure.extend(node_labels)
|
||||
|
||||
for i, link in enumerate(links):
|
||||
link_id = link.id
|
||||
link_labels = [
|
||||
f"link_{link_id}_id",
|
||||
f"link_{link_id}_n/a",
|
||||
f"link_{link_id}_n/a",
|
||||
f"link_{link_id}_n/a",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
link_labels.append(f"link_{link_id}_service_{serv}_load")
|
||||
|
||||
structure.extend(link_labels)
|
||||
return structure
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""
|
||||
Flat list of nodes' hardware, OS, file system, and service states.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
Each node has 3 elements plus 1 per service. It will have the following structure:
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
node1 hardware state,
|
||||
node1 OS state,
|
||||
node1 file system state,
|
||||
node1 service1 state,
|
||||
node1 service2 state,
|
||||
node1 serviceN state (one for each service),
|
||||
node2 hardware state,
|
||||
node2 OS state,
|
||||
node2 file system state,
|
||||
node2 service1 state,
|
||||
node2 service2 state,
|
||||
node2 serviceN state (one for each service),
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeStatuses observation component.
|
||||
|
||||
:param env: Training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
node_shape = [
|
||||
len(HardwareState) + 1,
|
||||
len(SoftwareState) + 1,
|
||||
len(FileSystemState) + 1,
|
||||
]
|
||||
services_shape = [len(SoftwareState) + 1] * self.env.num_services
|
||||
node_shape = node_shape + services_shape
|
||||
|
||||
shape = node_shape * self.env.num_nodes
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeStatuses`
|
||||
"""
|
||||
obs = []
|
||||
for _, node in self.env.nodes.items():
|
||||
hardware_state = node.hardware_state.value
|
||||
software_state = 0
|
||||
file_system_state = 0
|
||||
service_states = [0] * self.env.num_services
|
||||
|
||||
if isinstance(node, ActiveNode):
|
||||
software_state = node.software_state.value
|
||||
file_system_state = node.file_system_state_observed.value
|
||||
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
obs.extend(
|
||||
[
|
||||
hardware_state,
|
||||
software_state,
|
||||
file_system_state,
|
||||
*service_states,
|
||||
]
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
services = self.env.services_list
|
||||
|
||||
structure = []
|
||||
|
||||
for _, node in self.env.nodes.items():
|
||||
node_id = node.node_id
|
||||
structure.append(f"node_{node_id}_hardware_state_NONE")
|
||||
for state in HardwareState:
|
||||
structure.append(f"node_{node_id}_hardware_state_{state.name}")
|
||||
structure.append(f"node_{node_id}_software_state_NONE")
|
||||
for state in SoftwareState:
|
||||
structure.append(f"node_{node_id}_software_state_{state.name}")
|
||||
structure.append(f"node_{node_id}_file_system_state_NONE")
|
||||
for state in FileSystemState:
|
||||
structure.append(f"node_{node_id}_file_system_state_{state.name}")
|
||||
for service in services:
|
||||
structure.append(f"node_{node_id}_service_{service}_state_NONE")
|
||||
for state in SoftwareState:
|
||||
structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
|
||||
return structure
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""
|
||||
Flat list of traffic levels encoded into banded categories.
|
||||
|
||||
For each link, total traffic or traffic per service is encoded into a categorical value.
|
||||
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
||||
|
||||
* 0 = No traffic (0% of bandwidth)
|
||||
* 1 = No traffic (0%-33% of bandwidth)
|
||||
* 2 = No traffic (33%-66% of bandwidth)
|
||||
* 3 = No traffic (66%-100% of bandwidth)
|
||||
* 4 = No traffic (100% of bandwidth)
|
||||
|
||||
.. note::
|
||||
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
|
||||
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
|
||||
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a LinkTrafficLevels observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
|
||||
defaults to False
|
||||
:type combine_service_traffic: bool, optional
|
||||
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical
|
||||
value, defaults to 5
|
||||
:type quantisation_levels: int, optional
|
||||
"""
|
||||
if quantisation_levels < 3:
|
||||
_msg = (
|
||||
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
|
||||
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
|
||||
f"Resetting to default value (5)"
|
||||
)
|
||||
_LOGGER.warning(_msg)
|
||||
quantisation_levels = 5
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
self._combine_service_traffic: bool = combine_service_traffic
|
||||
self._quantisation_levels: int = quantisation_levels
|
||||
self._entries_per_link: int = 1
|
||||
|
||||
if not self._combine_service_traffic:
|
||||
self._entries_per_link = self.env.num_services
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
||||
"""
|
||||
obs = []
|
||||
for _, link in self.env.links.items():
|
||||
bandwidth = link.bandwidth
|
||||
if self._combine_service_traffic:
|
||||
loads = [link.get_current_load()]
|
||||
else:
|
||||
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||
|
||||
for load in loads:
|
||||
if load <= 0:
|
||||
traffic_level = 0
|
||||
elif load >= bandwidth:
|
||||
traffic_level = self._quantisation_levels - 1
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for _, link in self.env.links.items():
|
||||
link_id = link.id
|
||||
if self._combine_service_traffic:
|
||||
protocols = ["overall"]
|
||||
else:
|
||||
protocols = [protocol.name for protocol in link.protocol_list]
|
||||
|
||||
for p in protocols:
|
||||
for i in range(self._quantisation_levels):
|
||||
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
|
||||
return structure
|
||||
|
||||
|
||||
class AccessControlList(AbstractObservationComponent):
|
||||
"""Flat list of all the Access Control Rules in the Access Control List.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
acl_rule1 permission,
|
||||
acl_rule1 source_ip,
|
||||
acl_rule1 dest_ip,
|
||||
acl_rule1 protocol,
|
||||
acl_rule1 port,
|
||||
acl_rule1 position,
|
||||
acl_rule2 permission,
|
||||
acl_rule2 source_ip,
|
||||
acl_rule2 dest_ip,
|
||||
acl_rule2 protocol,
|
||||
acl_rule2 port,
|
||||
acl_rule2 position,
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
Terms (for ACL Observation Space):
|
||||
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
|
||||
# Number of ACL rules incremented by 1 for positions starting at index 0.
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
permission = acl_rule.permission
|
||||
source_ip = acl_rule.source_ip
|
||||
dest_ip = acl_rule.dest_ip
|
||||
protocol = acl_rule.protocol
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == RulePermissionType.DENY:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to source IP address
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == source_ip:
|
||||
source_ip_int = int(node.node_id) + 1
|
||||
break
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to dest IP address
|
||||
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == dest_ip:
|
||||
dest_ip_int = int(node.node_id) + 1
|
||||
if protocol == "ANY":
|
||||
protocol_int = 1
|
||||
else:
|
||||
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
|
||||
try:
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = None
|
||||
if port == "ANY":
|
||||
port_int = 1
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port) + 2
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
port_int = None
|
||||
# Add to current obs
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
# The Nothing or NA representation of 'NONE' ACL rules
|
||||
obs.extend([0, 0, 0, 0, 0, 0])
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for acl_rule in self.env.acl.acl:
|
||||
acl_rule_id = self.env.acl.acl.index(acl_rule)
|
||||
|
||||
for permission in RulePermissionType:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
|
||||
for service in self.env.services_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
|
||||
for port in self.env.ports_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
|
||||
|
||||
return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""
|
||||
Component-based observation space handler.
|
||||
|
||||
This allows users to configure observation spaces by mixing and matching components. Each component can also define
|
||||
further parameters to make them more flexible.
|
||||
"""
|
||||
|
||||
_REGISTRY: Final[Dict[str, type]] = {
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
"ACCESS_CONTROL_LIST": AccessControlList,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the observation handler."""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
# internal the observation space (unflattened version of space if flatten=True)
|
||||
self._space: spaces.Space
|
||||
# flattened version of the observation space
|
||||
self._flat_space: spaces.Space
|
||||
|
||||
self._observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
|
||||
if len(current_obs) == 1:
|
||||
self._observation = current_obs[0]
|
||||
else:
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Add a component for this handler to track.
|
||||
|
||||
:param obs_component: The component to add.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
"""
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Remove a component from this handler.
|
||||
|
||||
:param obs_component: Which component to remove. It must exist within this object's
|
||||
``registered_obs_components`` attribute.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
"""
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self) -> None:
|
||||
"""Rebuild the handler's composite observation space from its components."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
|
||||
# if there are multiple components, build a composite tuple space
|
||||
if len(component_spaces) == 1:
|
||||
self._space = component_spaces[0]
|
||||
else:
|
||||
self._space = spaces.Tuple(component_spaces)
|
||||
if len(component_spaces) > 0:
|
||||
self._flat_space = spaces.flatten_space(self._space)
|
||||
else:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_space
|
||||
else:
|
||||
return self._space
|
||||
|
||||
@property
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_observation
|
||||
else:
|
||||
return self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
|
||||
"""
|
||||
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
The expected format for the config dictionary is:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
config = {
|
||||
components: [
|
||||
{
|
||||
"name": "<COMPONENT1_NAME>"
|
||||
},
|
||||
{
|
||||
"name": "<COMPONENT2_NAME>"
|
||||
"options": {"opt1": val1, "opt2": val2}
|
||||
},
|
||||
{
|
||||
...
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
:return: Observation handler
|
||||
:rtype: primaite.environment.observations.ObservationsHandler
|
||||
"""
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
comp_class = cls._REGISTRY[comp_type]
|
||||
|
||||
# Create the component with options from the YAML
|
||||
options = component_cfg.get("options") or {}
|
||||
component = comp_class(env, **options)
|
||||
|
||||
handler.register(component)
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self) -> List[str]:
|
||||
"""
|
||||
Create a list of names for the features of the obs space.
|
||||
|
||||
The order of labels follows the flattened version of the space.
|
||||
"""
|
||||
# as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
|
||||
# to fake it. each component has to just hard-code the expected label order after flattening...
|
||||
|
||||
labels = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
labels.extend(obs_comp.structure)
|
||||
|
||||
return labels
|
||||
1408
src/primaite/environment/primaite_env.py
Normal file
1408
src/primaite/environment/primaite_env.py
Normal file
File diff suppressed because it is too large
Load Diff
386
src/primaite/environment/reward.py
Normal file
386
src/primaite/environment/reward.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements reward function."""
|
||||
from logging import Logger
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes: Dict[str, NodeUnion],
|
||||
final_nodes: Dict[str, NodeUnion],
|
||||
reference_nodes: Dict[str, NodeUnion],
|
||||
green_iers: Dict[str, "IER"],
|
||||
green_iers_reference: Dict[str, "IER"],
|
||||
red_iers: Dict[str, "IER"],
|
||||
step_count: int,
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
Args:
|
||||
initial_nodes: The nodes before red and blue agents take effect
|
||||
final_nodes: The nodes after red and blue agents take effect
|
||||
reference_nodes: The nodes if there had been no red or blue effect
|
||||
green_iers: The green IERs (should be running)
|
||||
red_iers: Should be stopeed (ideally) by the blue agent
|
||||
step_count: current step
|
||||
config_values: Config values
|
||||
"""
|
||||
reward_value: float = 0.0
|
||||
|
||||
# For each node, compare hardware state, SoftwareState, service states
|
||||
for node_key, final_node in final_nodes.items():
|
||||
initial_node = initial_nodes[node_key]
|
||||
reference_node = reference_nodes[node_key]
|
||||
|
||||
# Hardware State
|
||||
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Software State
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Service State
|
||||
if isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# File System State
|
||||
if isinstance(final_node, ActiveNode):
|
||||
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Go through each red IER - penalise if it is running
|
||||
for ier_key, ier_value in red_iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if ier_value.get_is_running():
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
reference_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
elif not live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_operating_state = final_node.hardware_state
|
||||
reference_node_operating_state = reference_node.hardware_state
|
||||
|
||||
if final_node_operating_state == reference_node_operating_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final (current) state of node (i.e. at every step)
|
||||
if reference_node_operating_state == HardwareState.ON:
|
||||
if final_node_operating_state == HardwareState.OFF:
|
||||
score += config_values.off_should_be_on
|
||||
elif final_node_operating_state == HardwareState.RESETTING:
|
||||
score += config_values.resetting_should_be_on
|
||||
else:
|
||||
pass
|
||||
elif reference_node_operating_state == HardwareState.OFF:
|
||||
if final_node_operating_state == HardwareState.ON:
|
||||
score += config_values.on_should_be_off
|
||||
elif final_node_operating_state == HardwareState.RESETTING:
|
||||
score += config_values.resetting_should_be_off
|
||||
else:
|
||||
pass
|
||||
elif reference_node_operating_state == HardwareState.RESETTING:
|
||||
if final_node_operating_state == HardwareState.ON:
|
||||
score += config_values.on_should_be_resetting
|
||||
elif final_node_operating_state == HardwareState.OFF:
|
||||
score += config_values.off_should_be_resetting
|
||||
elif final_node_operating_state == HardwareState.RESETTING:
|
||||
score += config_values.resetting
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_os_state = final_node.software_state
|
||||
reference_node_os_state = reference_node.software_state
|
||||
|
||||
if final_node_os_state == reference_node_os_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final (current) state of node (i.e. at every step)
|
||||
if reference_node_os_state == SoftwareState.GOOD:
|
||||
if final_node_os_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_good
|
||||
elif final_node_os_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif reference_node_os_state == SoftwareState.PATCHING:
|
||||
if final_node_os_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_node_os_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_node_os_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif reference_node_os_state == SoftwareState.COMPROMISED:
|
||||
if final_node_os_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_node_os_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_node_os_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(
|
||||
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_services: Dict[str, Service] = final_node.services
|
||||
reference_node_services: Dict[str, Service] = reference_node.services
|
||||
|
||||
for service_key, final_service in final_node_services.items():
|
||||
reference_service = reference_node_services[service_key]
|
||||
final_service = final_node_services[service_key]
|
||||
|
||||
if final_service.software_state == reference_service.software_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final state of node (i.e. at every step)
|
||||
if reference_service.software_state == SoftwareState.GOOD:
|
||||
if final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_good
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_good
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif reference_service.software_state == SoftwareState.PATCHING:
|
||||
if final_service.software_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_patching
|
||||
elif final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif reference_service.software_state == SoftwareState.COMPROMISED:
|
||||
if final_service.software_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_compromised
|
||||
else:
|
||||
pass
|
||||
elif reference_service.software_state == SoftwareState.OVERWHELMED:
|
||||
if final_service.software_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_overwhelmed
|
||||
elif final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_overwhelmed
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_overwhelmed
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_file_system_state = final_node.file_system_state_actual
|
||||
reference_node_file_system_state = reference_node.file_system_state_actual
|
||||
|
||||
final_node_scanning_state = final_node.file_system_scanning
|
||||
reference_node_scanning_state = reference_node.file_system_scanning
|
||||
|
||||
# File System State
|
||||
if final_node_file_system_state == reference_node_file_system_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final state of node (i.e. at every step)
|
||||
if reference_node_file_system_state == FileSystemState.GOOD:
|
||||
if final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_good
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_good
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_good
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.REPAIRING:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.RESTORING:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.CORRUPT:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.DESTROYED:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
# Scanning State
|
||||
if final_node_scanning_state == reference_node_scanning_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# We're scanning the file system which incurs a penalty (as it slows down systems)
|
||||
score += config_values.scanning
|
||||
|
||||
return score
|
||||
11
src/primaite/exceptions.py
Normal file
11
src/primaite/exceptions.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
class PrimaiteError(Exception):
|
||||
"""The root PrimAITe Error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RLlibAgentError(PrimaiteError):
|
||||
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
|
||||
|
||||
pass
|
||||
2
src/primaite/links/__init__.py
Normal file
2
src/primaite/links/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Network connections between nodes in the simulation."""
|
||||
114
src/primaite/links/link.py
Normal file
114
src/primaite/links/link.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The link class."""
|
||||
from typing import List
|
||||
|
||||
from primaite.common.protocol import Protocol
|
||||
|
||||
|
||||
class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
|
||||
"""
|
||||
Initialise a Link within the simulated network.
|
||||
|
||||
:param _id: The IER id
|
||||
:param _bandwidth: The bandwidth of the link (bps)
|
||||
:param _source_node_name: The name of the source node
|
||||
:param _dest_node_name: The name of the destination node
|
||||
:param _protocols: The protocols to add to the link
|
||||
"""
|
||||
self.id: str = _id
|
||||
self.bandwidth: int = _bandwidth
|
||||
self.source_node_name: str = _source_node_name
|
||||
self.dest_node_name: str = _dest_node_name
|
||||
self.protocol_list: List[Protocol] = []
|
||||
|
||||
# Add the default protocols
|
||||
for protocol_name in _services:
|
||||
self.add_protocol(protocol_name)
|
||||
|
||||
def add_protocol(self, _protocol: str) -> None:
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to be added (enum)
|
||||
"""
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets link ID.
|
||||
|
||||
Returns:
|
||||
Link ID
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self) -> str:
|
||||
"""
|
||||
Gets source node name.
|
||||
|
||||
Returns:
|
||||
Source node name
|
||||
"""
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self) -> str:
|
||||
"""
|
||||
Gets destination node name.
|
||||
|
||||
Returns:
|
||||
Destination node name
|
||||
"""
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self) -> int:
|
||||
"""
|
||||
Gets bandwidth of link.
|
||||
|
||||
Returns:
|
||||
Link bandwidth (bps)
|
||||
"""
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self) -> List[Protocol]:
|
||||
"""
|
||||
Gets list of protocols on this link.
|
||||
|
||||
Returns:
|
||||
List of protocols on this link
|
||||
"""
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self) -> int:
|
||||
"""
|
||||
Gets current total load on this link.
|
||||
|
||||
Returns:
|
||||
Total load on this link (bps)
|
||||
"""
|
||||
total_load = 0
|
||||
for protocol in self.protocol_list:
|
||||
total_load += protocol.get_load()
|
||||
return total_load
|
||||
|
||||
def add_protocol_load(self, _protocol: str, _load: int) -> None:
|
||||
"""
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to load
|
||||
_load: The amount to load (bps)
|
||||
"""
|
||||
for protocol in self.protocol_list:
|
||||
if protocol.get_name() == _protocol:
|
||||
protocol.add_load(_load)
|
||||
else:
|
||||
pass
|
||||
|
||||
def clear_traffic(self) -> None:
|
||||
"""Clears all traffic on this link."""
|
||||
for protocol in self.protocol_list:
|
||||
protocol.clear_load()
|
||||
57
src/primaite/main.py
Normal file
57
src/primaite/main.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
session = PrimaiteSession(
|
||||
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
)
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
session.evaluate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
parser.add_argument("--load")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.load:
|
||||
run(session_path=args.load)
|
||||
else:
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
2
src/primaite/nodes/__init__.py
Normal file
2
src/primaite/nodes/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Nodes represent network hosts in the simulation."""
|
||||
208
src/primaite/nodes/active_node.py
Normal file
208
src/primaite/nodes/active_node.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActiveNode(Node):
|
||||
"""Active Node class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an active node.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
:param node_type: The node type (enum)
|
||||
:param priority: The node priority (enum)
|
||||
:param hardware_state: The node Hardware State
|
||||
:param ip_address: The node IP address
|
||||
:param software_state: The node Software State
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
self.ip_address: str = ip_address
|
||||
# Related to Software
|
||||
self._software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
# Related to File System
|
||||
self.file_system_state_actual: FileSystemState = file_system_state
|
||||
self.file_system_state_observed: FileSystemState = file_system_state
|
||||
self.file_system_scanning: bool = False
|
||||
self.file_system_scanning_count: int = 0
|
||||
self.file_system_action_count: int = 0
|
||||
|
||||
@property
|
||||
def software_state(self) -> SoftwareState:
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
:return: The software_state.
|
||||
"""
|
||||
return self._software_state
|
||||
|
||||
@software_state.setter
|
||||
def software_state(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
:param software_state: Software State.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be "
|
||||
f"changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
Args:
|
||||
software_state: Software State
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
if self._software_state != SoftwareState.COMPROMISED:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be changed."
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def update_os_patching_status(self) -> None:
|
||||
"""Updates operating system status based on patching cycle."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self._software_state = SoftwareState.GOOD
|
||||
|
||||
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
Args:
|
||||
file_system_state: File system state
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so File System State "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
Use for green PoL to prevent it overturning a compromised state
|
||||
|
||||
Args:
|
||||
file_system_state: File system state
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
if (
|
||||
self.file_system_state_actual != FileSystemState.CORRUPT
|
||||
and self.file_system_state_actual != FileSystemState.DESTROYED
|
||||
):
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so File System State (if not "
|
||||
f"compromised) cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def start_file_system_scan(self) -> None:
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def update_file_system_state(self) -> None:
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
# Deprecate both the action count (for restoring or reparing) and the scanning count
|
||||
self.file_system_action_count -= 1
|
||||
self.file_system_scanning_count -= 1
|
||||
|
||||
# Reparing / Restoring updates
|
||||
if self.file_system_action_count <= 0:
|
||||
self.file_system_action_count = 0
|
||||
if (
|
||||
self.file_system_state_actual == FileSystemState.REPAIRING
|
||||
or self.file_system_state_actual == FileSystemState.RESTORING
|
||||
):
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
|
||||
# Scanning updates
|
||||
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
|
||||
self.file_system_state_observed = self.file_system_state_actual
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the reset count & makes software and file state to GOOD."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting software and file state to GOOD."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.software_state = SoftwareState.GOOD
|
||||
79
src/primaite/nodes/node.py
Normal file
79
src/primaite/nodes/node.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The base Node class."""
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
|
||||
class Node:
|
||||
"""Node class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a node.
|
||||
|
||||
:param node_id: The node id.
|
||||
:param name: The name of the node.
|
||||
:param node_type: The type of the node.
|
||||
:param priority: The priority of the node.
|
||||
:param hardware_state: The state of the node.
|
||||
:param config_values: Config values.
|
||||
"""
|
||||
self.node_id: Final[str] = node_id
|
||||
self.name: Final[str] = name
|
||||
self.node_type: Final[NodeType] = node_type
|
||||
self.priority = priority
|
||||
self.hardware_state: HardwareState = hardware_state
|
||||
self.resetting_count: int = 0
|
||||
self.config_values: TrainingConfig = config_values
|
||||
self.booting_count: int = 0
|
||||
self.shutting_down_count: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns the name of the node."""
|
||||
return self.name
|
||||
|
||||
def turn_on(self) -> None:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self) -> None:
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
self.shutting_down_count = self.config_values.node_shutdown_duration
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets the node state to Resetting and starts the reset count."""
|
||||
self.hardware_state = HardwareState.RESETTING
|
||||
self.resetting_count = self.config_values.node_reset_duration
|
||||
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the resetting count."""
|
||||
self.resetting_count -= 1
|
||||
if self.resetting_count <= 0:
|
||||
self.resetting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self) -> None:
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
self.shutting_down_count = 0
|
||||
self.hardware_state = HardwareState.OFF
|
||||
94
src/primaite/nodes/node_state_instruction_green.py
Normal file
94
src/primaite/nodes/node_state_instruction_green.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_node_id: str,
|
||||
_node_pol_type: "NodePOLType",
|
||||
_service_name: str,
|
||||
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction.
|
||||
|
||||
:param _id: The node state instruction id
|
||||
:param _start_step: The start step of the instruction
|
||||
:param _end_step: The end step of the instruction
|
||||
:param _node_id: The id of the associated node
|
||||
:param _node_pol_type: The pattern of life type
|
||||
:param _service_name: The service name
|
||||
:param _state: The state (node or service)
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type: "NodePOLType" = _node_pol_type
|
||||
self.service_name: str = _service_name # Not used when not a service instruction
|
||||
# TODO: confirm type of state
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
|
||||
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self) -> "NodePOLType":
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
return self.state
|
||||
143
src/primaite/nodes/node_state_instruction_red.py
Normal file
143
src/primaite/nodes/node_state_instruction_red.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionRed:
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_target_node_id: str,
|
||||
_pol_initiator: "NodePOLInitiator",
|
||||
_pol_type: NodePOLType,
|
||||
pol_protocol: str,
|
||||
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
_pol_source_node_id: str,
|
||||
_pol_source_node_service: str,
|
||||
_pol_source_node_service_state: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction for the red agent.
|
||||
|
||||
:param _id: The node state instruction id
|
||||
:param _start_step: The start step of the instruction
|
||||
:param _end_step: The end step of the instruction
|
||||
:param _target_node_id: The id of the associated node
|
||||
:param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
|
||||
:param _pol_type: The pattern of life type
|
||||
:param pol_protocol: The pattern of life protocol/service affected
|
||||
:param _pol_state: The state (node or service)
|
||||
:param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.target_node_id: str = _target_node_id
|
||||
self.initiator: "NodePOLInitiator" = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name: str = pol_protocol # Not used when not a service instruction
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
|
||||
self.source_node_id: str = _pol_source_node_id
|
||||
self.source_node_service: str = _pol_source_node_service
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self) -> "NodePOLInitiator":
|
||||
"""
|
||||
Gets the initiator.
|
||||
|
||||
Returns:
|
||||
The initiator
|
||||
"""
|
||||
return self.initiator
|
||||
|
||||
def get_pol_type(self) -> NodePOLType:
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node id
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self) -> str:
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service
|
||||
"""
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self) -> str:
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service state
|
||||
"""
|
||||
return self.source_node_service_state
|
||||
42
src/primaite/nodes/passive_node.py
Normal file
42
src/primaite/nodes/passive_node.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
|
||||
class PassiveNode(Node):
|
||||
"""The Passive Node class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a passive node.
|
||||
|
||||
:param node_id: The node id.
|
||||
:param name: The name of the node.
|
||||
:param node_type: The type of the node.
|
||||
:param priority: The priority of the node.
|
||||
:param hardware_state: The state of the node.
|
||||
:param config_values: Config values.
|
||||
"""
|
||||
# Pass through to Super for now
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
"""
|
||||
Gets the node IP address as an empty string.
|
||||
|
||||
No concept of IP address for passive nodes for now.
|
||||
|
||||
:return: The node IP address.
|
||||
"""
|
||||
return ""
|
||||
190
src/primaite/nodes/service_node.py
Normal file
190
src/primaite/nodes/service_node.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceNode(ActiveNode):
|
||||
"""ServiceNode class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a Service Node.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
:param node_type: The node type (enum)
|
||||
:param priority: The node priority (enum)
|
||||
:param hardware_state: The node Hardware State
|
||||
:param ip_address: The node IP address
|
||||
:param software_state: The node Software State
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(
|
||||
node_id,
|
||||
name,
|
||||
node_type,
|
||||
priority,
|
||||
hardware_state,
|
||||
ip_address,
|
||||
software_state,
|
||||
file_system_state,
|
||||
config_values,
|
||||
)
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service) -> None:
|
||||
"""
|
||||
Adds a service to the node.
|
||||
|
||||
:param service: The service to add
|
||||
"""
|
||||
self.services[service.name] = service
|
||||
|
||||
def has_service(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is on a node.
|
||||
|
||||
:param protocol_name: The service (protocol)e.
|
||||
:return: True if service (protocol) is on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def service_running(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
if service_value.software_state != SoftwareState.PATCHING:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
if service_value.software_state == SoftwareState.OVERWHELMED:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
# Can't set to compromised if you're in a patching state
|
||||
if (
|
||||
software_state == SoftwareState.COMPROMISED
|
||||
and service_value.software_state != SoftwareState.PATCHING
|
||||
) or software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.services[<key>]:{protocol_name}, "
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
Done if the software_state is not "compromised".
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
if service_value.software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.services[<key>]:{protocol_name}, "
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name: str) -> SoftwareState:
|
||||
"""
|
||||
Gets the state of a service.
|
||||
|
||||
:return: The software_state of the service.
|
||||
"""
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
return service_value.software_state
|
||||
|
||||
def update_services_patching_status(self) -> None:
|
||||
"""Updates the patching counter for any service that are patching."""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_value.software_state == SoftwareState.PATCHING:
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Update resetting counter and set software state if it reached 0."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self) -> None:
|
||||
"""Update booting counter and set software to good if it reached 0."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
34
src/primaite/notebooks/__init__.py
Normal file
34
src/primaite/notebooks/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from logging import Logger
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session() -> None:
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
Currently only works on Windows OS.
|
||||
|
||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||
"""
|
||||
if importlib.util.find_spec("jupyter") is not None:
|
||||
jupyter_cmd = "python3 -m jupyter lab"
|
||||
if sys.platform == "win32":
|
||||
jupyter_cmd = "jupyter lab"
|
||||
|
||||
working_dir = os.getcwd()
|
||||
os.chdir(PRIMAITE_PATHS.user_notebooks_path)
|
||||
subprocess.Popen(jupyter_cmd)
|
||||
os.chdir(working_dir)
|
||||
else:
|
||||
# Jupyter is not installed
|
||||
_LOGGER.error("Cannot start jupyter lab as it is not installed")
|
||||
2
src/primaite/pol/__init__.py
Normal file
2
src/primaite/pol/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Pattern of Life- Represents the actions of users on the network."""
|
||||
264
src/primaite/pol/green_pol.py
Normal file
264
src/primaite/pol/green_pol.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_iers(
|
||||
network: MultiGraph,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
links: Dict[str, Link],
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
nodes: The nodes within the environment
|
||||
links: The links within the environment
|
||||
iers: The IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number.
|
||||
"""
|
||||
if _VERBOSE:
|
||||
print("Applying IERs")
|
||||
|
||||
# Go through each IER and check the conditions for it being applied
|
||||
# If everything is in place, apply the IER protocol load to the relevant links
|
||||
for ier_key, ier_value in iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
protocol = ier_value.get_protocol()
|
||||
port = ier_value.get_port()
|
||||
load = ier_value.get_load()
|
||||
source_node_id = ier_value.get_source_node_id()
|
||||
dest_node_id = ier_value.get_dest_node_id()
|
||||
|
||||
# Need to set the running status to false first for all IERs
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
source_valid = True
|
||||
dest_valid = True
|
||||
acl_block = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
|
||||
# Get the source and destination node for this link
|
||||
source_node = nodes[source_node_id]
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
|
||||
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
source_node.hardware_state == HardwareState.ON
|
||||
and source_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
source_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
source_valid = False
|
||||
elif source_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
# TO DO
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if (
|
||||
source_node.hardware_state == HardwareState.ON
|
||||
and source_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if source_node.has_service(protocol):
|
||||
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
dest_valid = False
|
||||
elif dest_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
if dest_node.has_service(protocol):
|
||||
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
|
||||
dest_valid = True
|
||||
else:
|
||||
dest_valid = False
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
dest_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.ip_address
|
||||
+ ", dest: "
|
||||
+ dest_node.ip_address
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
|
||||
# Check whether both the source and destination are valid, and there's no ACL block
|
||||
if source_valid and dest_valid and not acl_block:
|
||||
# Load up the link(s) with the traffic
|
||||
|
||||
if _VERBOSE:
|
||||
print("Source, Dest and ACL valid")
|
||||
|
||||
# Get the shortest path (i.e. nodes) between source and destination
|
||||
path_node_list = shortest_path(network, source_node, dest_node)
|
||||
path_node_list_length = len(path_node_list)
|
||||
path_valid = True
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
for node in path_node_list:
|
||||
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
if _VERBOSE:
|
||||
print("Applying IER to link(s)")
|
||||
count = 0
|
||||
link_capacity_exceeded = False
|
||||
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
# Now apply the new loads to the links
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count],
|
||||
path_node_list[count + 1],
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
pass
|
||||
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[str, NodeStateInstructionGreen],
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
node_pol: The node pattern of life to apply
|
||||
step: The step number.
|
||||
"""
|
||||
if _VERBOSE:
|
||||
print("Applying Node PoL")
|
||||
|
||||
for key, node_instruction in node_pol.items():
|
||||
start_step = node_instruction.get_start_step()
|
||||
stop_step = node_instruction.get_end_step()
|
||||
node_id = node_instruction.get_node_id()
|
||||
node_pol_type = node_instruction.get_node_pol_type()
|
||||
service_name = node_instruction.get_service_name()
|
||||
state = node_instruction.get_state()
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
node = nodes[node_id]
|
||||
|
||||
if node_pol_type == NodePOLType.OPERATING:
|
||||
# Change hardware state
|
||||
node.hardware_state = state
|
||||
elif node_pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_software_state_if_not_compromised(state)
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ServiceNode):
|
||||
node.set_service_state_if_not_compromised(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_file_system_state_if_not_compromised(state)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
147
src/primaite/pol/ier.py
Normal file
147
src/primaite/pol/ier.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
|
||||
|
||||
class IER(object):
|
||||
"""Information Exchange Requirement class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_load: int,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_source_node_id: str,
|
||||
_dest_node_id: str,
|
||||
_mission_criticality: int,
|
||||
_running: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an Information Exchange Request.
|
||||
|
||||
:param _id: The IER id
|
||||
:param _start_step: The step when this IER should start
|
||||
:param _end_step: The step when this IER should end
|
||||
:param _load: The load this IER should put on a link (bps)
|
||||
:param _protocol: The protocol of this IER
|
||||
:param _port: The port this IER runs on
|
||||
:param _source_node_id: The source node ID
|
||||
:param _dest_node_id: The destination node ID
|
||||
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
:param _running: Indicates whether the IER is currently running
|
||||
"""
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.source_node_id: str = _source_node_id
|
||||
self.dest_node_id: str = _dest_node_id
|
||||
self.load: int = _load
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
self.mission_criticality: int = _mission_criticality
|
||||
self.running: bool = _running
|
||||
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets IER ID.
|
||||
|
||||
Returns:
|
||||
IER ID
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets IER start step.
|
||||
|
||||
Returns:
|
||||
IER start step
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets IER end step.
|
||||
|
||||
Returns:
|
||||
IER end step
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets IER load.
|
||||
|
||||
Returns:
|
||||
IER load
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets IER protocol.
|
||||
|
||||
Returns:
|
||||
IER protocol
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets IER port.
|
||||
|
||||
Returns:
|
||||
IER port
|
||||
"""
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
|
||||
Returns:
|
||||
IER source node ID
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
|
||||
Returns:
|
||||
IER destination node ID
|
||||
"""
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self) -> bool:
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
Returns:
|
||||
True if running
|
||||
"""
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value: bool) -> None:
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
|
||||
Args:
|
||||
_value: running status
|
||||
"""
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self) -> int:
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
Returns:
|
||||
Mission criticality value (0 lowest to 5 highest)
|
||||
"""
|
||||
return self.mission_criticality
|
||||
353
src/primaite/pol/red_agent_pol.py
Normal file
353
src/primaite/pol/red_agent_pol.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(
|
||||
network: MultiGraph,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
links: Dict[str, Link],
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
nodes: The nodes within the environment
|
||||
links: The links within the environment
|
||||
iers: The red agent IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number.
|
||||
"""
|
||||
# Go through each IER and check the conditions for it being applied
|
||||
# If everything is in place, apply the IER protocol load to the relevant links
|
||||
for ier_key, ier_value in iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
protocol = ier_value.get_protocol()
|
||||
port = ier_value.get_port()
|
||||
load = ier_value.get_load()
|
||||
source_node_id = ier_value.get_source_node_id()
|
||||
dest_node_id = ier_value.get_dest_node_id()
|
||||
|
||||
# Need to set the running status to false first for all IERs
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
source_valid = True
|
||||
dest_valid = True
|
||||
acl_block = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
|
||||
# Get the source and destination node for this link
|
||||
source_node = nodes[source_node_id]
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
source_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
source_valid = False
|
||||
elif source_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
# TO DO
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
|
||||
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
|
||||
# to change according to duck typing.
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if dest_node.hardware_state == HardwareState.ON:
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
dest_valid = False
|
||||
elif dest_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if dest_node.hardware_state == HardwareState.ON:
|
||||
if dest_node.has_service(protocol):
|
||||
# We don't care what state the destination service is in for an IER
|
||||
dest_valid = True
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
dest_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.ip_address
|
||||
+ ", dest: "
|
||||
+ dest_node.ip_address
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
|
||||
# Check whether both the source and destination are valid, and there's no ACL block
|
||||
if source_valid and dest_valid and not acl_block:
|
||||
# Load up the link(s) with the traffic
|
||||
|
||||
if _VERBOSE:
|
||||
print("Source, Dest and ACL valid")
|
||||
|
||||
# Get the shortest path (i.e. nodes) between source and destination
|
||||
path_node_list = shortest_path(network, source_node, dest_node)
|
||||
path_node_list_length = len(path_node_list)
|
||||
path_valid = True
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
# We're assuming here that red agents can get past switches that are patching
|
||||
for node in path_node_list:
|
||||
if node.hardware_state != HardwareState.ON:
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
if _VERBOSE:
|
||||
print("Applying IER to link(s)")
|
||||
count = 0
|
||||
link_capacity_exceeded = False
|
||||
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
# Now apply the new loads to the links
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count],
|
||||
path_node_list[count + 1],
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
print("Red IER was allowed to run in step " + str(step))
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Red IER was NOT allowed to run in step " + str(step))
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def apply_red_agent_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
iers: Dict[str, IER],
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
iers: The red agent IERs
|
||||
node_pol: The red agent node pattern of life to apply
|
||||
step: The step number.
|
||||
"""
|
||||
if _VERBOSE:
|
||||
print("Applying Node Red Agent PoL")
|
||||
|
||||
for key, node_instruction in node_pol.items():
|
||||
start_step = node_instruction.get_start_step()
|
||||
stop_step = node_instruction.get_end_step()
|
||||
target_node_id = node_instruction.get_target_node_id()
|
||||
initiator = node_instruction.get_initiator()
|
||||
pol_type = node_instruction.get_pol_type()
|
||||
service_name = node_instruction.get_service_name()
|
||||
state = node_instruction.get_state()
|
||||
source_node_id = node_instruction.get_source_node_id()
|
||||
source_node_service_name = node_instruction.get_source_node_service()
|
||||
source_node_service_state_value = node_instruction.get_source_node_service_state()
|
||||
|
||||
passed_checks = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
target_node: NodeUnion = nodes[target_node_id]
|
||||
|
||||
# check if the initiator type is a str, and if so, cast it as
|
||||
# NodePOLInitiator
|
||||
if isinstance(initiator, str):
|
||||
initiator = NodePOLInitiator[initiator]
|
||||
|
||||
# Based the action taken on the initiator type
|
||||
if initiator == NodePOLInitiator.DIRECT:
|
||||
# No conditions required, just apply the change
|
||||
passed_checks = True
|
||||
elif initiator == NodePOLInitiator.IER:
|
||||
# Need to check there is a red IER incoming
|
||||
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
|
||||
elif initiator == NodePOLInitiator.SERVICE:
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
if source_node.has_service(source_node_service_name):
|
||||
if (
|
||||
source_node.get_service_state(source_node_service_name)
|
||||
== SoftwareState[source_node_service_state_value]
|
||||
):
|
||||
passed_checks = True
|
||||
else:
|
||||
# Do nothing, no matching state value
|
||||
pass
|
||||
else:
|
||||
# Do nothing, service not on this node
|
||||
pass
|
||||
else:
|
||||
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
|
||||
|
||||
# Only apply the PoL if the checks have passed (based on the initiator type)
|
||||
if passed_checks:
|
||||
# Apply the change
|
||||
if pol_type == NodePOLType.OPERATING:
|
||||
# Change hardware state
|
||||
target_node.hardware_state = state
|
||||
elif pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.software_state = state
|
||||
elif pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
if isinstance(target_node, ServiceNode):
|
||||
target_node.set_service_state(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
|
||||
def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool:
|
||||
"""Checks if the RED IER is incoming.
|
||||
|
||||
:param node: Destination node of the IER
|
||||
:type node: NodeUnion
|
||||
:param iers: Directory of IERs
|
||||
:type iers: Dict[str,IER]
|
||||
:param node_pol_type: Type of Pattern-Of-Life
|
||||
:type node_pol_type: NodePOLType
|
||||
:return: Whether the RED IER is incoming.
|
||||
:rtype: bool
|
||||
"""
|
||||
node_id = node.node_id
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if (
|
||||
node_pol_type == NodePOLType.OPERATING
|
||||
or node_pol_type == NodePOLType.OS
|
||||
or node_pol_type == NodePOLType.FILE
|
||||
):
|
||||
# It's looking to change hardware state, file system or SoftwareState, so valid
|
||||
return True
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Check if the service is present on the node and running
|
||||
ier_protocol = ier_value.get_protocol()
|
||||
if isinstance(node, ServiceNode):
|
||||
if node.has_service(ier_protocol):
|
||||
if node.service_running(ier_protocol):
|
||||
# Matching service is present and running, so valid
|
||||
return True
|
||||
else:
|
||||
# Service is present, but not running
|
||||
return False
|
||||
else:
|
||||
# Service is not present
|
||||
return False
|
||||
else:
|
||||
# Not a service node
|
||||
return False
|
||||
else:
|
||||
# Shouldn't get here - instruction type is undefined
|
||||
return False
|
||||
else:
|
||||
# The IER destination is not this node, or the IER is not running
|
||||
return False
|
||||
225
src/primaite/primaite_session.py
Normal file
225
src/primaite/primaite_session.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Tuple, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""
|
||||
The PrimaiteSession class.
|
||||
|
||||
Provides a single learning and evaluation entry point for all training and lay down configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
# set load_session to true
|
||||
self.is_load_session = True
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
|
||||
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path, legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path,
|
||||
self.session_path,
|
||||
self.legacy_training_config,
|
||||
self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
raise ValueError
|
||||
|
||||
self.session_path: Path = self._agent_session.session_path
|
||||
self.timestamp_str: str = self._agent_session.timestamp_str
|
||||
self.learning_path: Path = self._agent_session.learning_path
|
||||
self.evaluation_path: Path = self._agent_session.evaluation_path
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(**kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the learn av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the eval av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the learn all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the eval all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def metadata_file_as_dict(self) -> Dict[str, Any]:
|
||||
"""Read the session_metadata.json file and return as a dict."""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
return json.load(file)
|
||||
2
src/primaite/setup/__init__.py
Normal file
2
src/primaite/setup/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Utilities to prepare the user's data folders."""
|
||||
22
src/primaite/setup/_package_data/primaite_config.yaml
Normal file
22
src/primaite/setup/_package_data/primaite_config.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
# The main PrimAITE application config file
|
||||
|
||||
# Logging
|
||||
logging:
|
||||
log_level: INFO
|
||||
logger_format:
|
||||
DEBUG: '%(asctime)s: %(message)s'
|
||||
INFO: '%(asctime)s: %(message)s'
|
||||
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
|
||||
# Session
|
||||
session:
|
||||
outputs:
|
||||
plots:
|
||||
size:
|
||||
auto_size: false
|
||||
width: 1500
|
||||
height: 900
|
||||
template: plotly_white
|
||||
range_slider: false
|
||||
14
src/primaite/setup/old_installation_clean_up.py
Normal file
14
src/primaite/setup/old_installation_clean_up.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run() -> None:
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
35
src/primaite/setup/reset_demo_notebooks.py
Normal file
35
src/primaite/setup/reset_demo_notebooks.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
|
||||
"""
|
||||
notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
|
||||
for subdir, dirs, files in os.walk(notebooks_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
|
||||
target_fp = PRIMAITE_PATHS.user_notebooks_path / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
if overwrite_existing and not copy_file:
|
||||
copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp))
|
||||
|
||||
if copy_file:
|
||||
shutil.copy2(fp, target_fp)
|
||||
_LOGGER.info(f"Reset example notebook: {target_fp}")
|
||||
35
src/primaite/setup/reset_example_configs.py
Normal file
35
src/primaite/setup/reset_example_configs.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited config on or off.
|
||||
"""
|
||||
configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")
|
||||
|
||||
for subdir, dirs, files in os.walk(configs_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
|
||||
target_fp = PRIMAITE_PATHS.user_config_path / "example_config" / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
if overwrite_existing and not copy_file:
|
||||
copy_file = not filecmp.cmp(fp, target_fp)
|
||||
|
||||
if copy_file:
|
||||
shutil.copy2(fp, target_fp)
|
||||
_LOGGER.info(f"Reset example config: {target_fp}")
|
||||
2
src/primaite/transactions/__init__.py
Normal file
2
src/primaite/transactions/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Record data of the system's state and agent's observations and actions."""
|
||||
102
src/primaite/transactions/transaction.py
Normal file
102
src/primaite/transactions/transaction.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
:param agent_identifier: An identifier for the agent in use
|
||||
:param episode_number: The episode number
|
||||
:param step_number: The step number
|
||||
"""
|
||||
self.timestamp: datetime = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier: AgentIdentifier = agent_identifier
|
||||
"The agent identifier"
|
||||
self.episode_number: int = episode_number
|
||||
"The episode number"
|
||||
self.step_number: int = step_number
|
||||
"The step number"
|
||||
self.obs_space: "spaces.Space" = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward: Optional[float] = None
|
||||
"The reward value"
|
||||
self.action_space: Optional[int] = None
|
||||
"The action space invoked by the agent"
|
||||
self.obs_space_description: Optional[List[str]] = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""
|
||||
Converts the Transaction to a csv data row and provides a header.
|
||||
|
||||
:return: A tuple consisting of (header, data).
|
||||
"""
|
||||
if isinstance(self.action_space, int):
|
||||
action_length = self.action_space
|
||||
else:
|
||||
action_length = self.action_space.size
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + self.obs_space_description
|
||||
|
||||
row = [
|
||||
str(self.timestamp),
|
||||
str(self.episode_number),
|
||||
str(self.step_number),
|
||||
str(self.reward),
|
||||
]
|
||||
row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist()
|
||||
return header, row
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
:param action_space: The action space
|
||||
:return: The action space as an array of strings
|
||||
"""
|
||||
if isinstance(action_space, list):
|
||||
return [str(i) for i in action_space]
|
||||
else:
|
||||
return [str(action_space)]
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
:param obs_space: The observation space
|
||||
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
:param obs_features: The number of features associated with the asset
|
||||
:return: The observation space as an array of strings
|
||||
"""
|
||||
return_array = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
return_array.append(str(obs_space[x][y]))
|
||||
|
||||
return return_array
|
||||
2
src/primaite/utils/__init__.py
Normal file
2
src/primaite/utils/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Utilities for PrimAITE."""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user