Merge remote-tracking branch 'origin/dev' into feature/2769-implement-user-account-action-space
This commit is contained in:
@@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
main port they're assigned.
|
||||
|
||||
### Changed
|
||||
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
|
||||
|
||||
### Fixed
|
||||
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
|
||||
- Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install`
|
||||
and `uninstall` methods in the `Node` class.
|
||||
- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the
|
||||
|
||||
@@ -23,8 +23,10 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
"""Name of the file, used for querying simulation state dictionary."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to the file in the observation."""
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool) -> None:
|
||||
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
|
||||
"""
|
||||
Initialise a file observation instance.
|
||||
|
||||
@@ -34,9 +36,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
:type where: WhereType
|
||||
:param include_num_access: Whether to include the number of accesses to the file in the observation.
|
||||
:type include_num_access: bool
|
||||
:param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False,
|
||||
the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.include_num_access: bool = include_num_access
|
||||
self.file_system_requires_scan: bool = file_system_requires_scan
|
||||
|
||||
self.default_observation: ObsType = {"health_status": 0}
|
||||
if self.include_num_access:
|
||||
@@ -74,7 +80,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {"health_status": file_state["visible_status"]}
|
||||
if self.file_system_requires_scan:
|
||||
health_status = file_state["visible_status"]
|
||||
else:
|
||||
health_status = file_state["health_status"]
|
||||
obs = {"health_status": health_status}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
|
||||
return obs
|
||||
@@ -104,8 +114,15 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed file observation instance.
|
||||
:rtype: FileObservation
|
||||
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
|
||||
the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
|
||||
return cls(
|
||||
where=parent_where + ["files", config.file_name],
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
)
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
@@ -122,9 +139,16 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
"""Number of spaces for file observations in this folder."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether files in this folder should include the number of accesses in their observation."""
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(
|
||||
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
|
||||
self,
|
||||
where: WhereType,
|
||||
files: Iterable[FileObservation],
|
||||
num_files: int,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
@@ -138,12 +162,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
:type num_files: int
|
||||
:param include_num_access: Whether to include the number of accesses to files in the observation.
|
||||
:type include_num_access: bool
|
||||
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
|
||||
the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.file_system_requires_scan: bool = file_system_requires_scan
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
|
||||
self.files.append(
|
||||
FileObservation(
|
||||
where=None,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=self.file_system_requires_scan,
|
||||
)
|
||||
)
|
||||
while len(self.files) > num_files:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
@@ -168,7 +203,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
health_status = folder_state["health_status"]
|
||||
if self.file_system_requires_scan:
|
||||
health_status = folder_state["visible_status"]
|
||||
else:
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
obs = {}
|
||||
|
||||
@@ -209,6 +247,13 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
# pass down shared/common config items
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
file_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
return cls(
|
||||
where=where,
|
||||
files=files,
|
||||
num_files=config.num_files,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
)
|
||||
|
||||
@@ -48,6 +48,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""A dict containing which traffic types are to be included in the observation."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to files observations on this host."""
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -64,6 +68,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
include_nmne: bool,
|
||||
monitored_traffic: Optional[Dict],
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a host observation instance.
|
||||
@@ -95,6 +100,9 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
:type monitored_traffic: Dict
|
||||
:param include_num_access: Flag to include the number of accesses to files.
|
||||
:type include_num_access: bool
|
||||
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
@@ -120,7 +128,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
self.folders: List[FolderObservation] = folders
|
||||
while len(self.folders) < num_folders:
|
||||
self.folders.append(
|
||||
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
|
||||
FolderObservation(
|
||||
where=None,
|
||||
files=[],
|
||||
num_files=num_files,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=file_system_requires_scan,
|
||||
)
|
||||
)
|
||||
while len(self.folders) > num_folders:
|
||||
truncated_folder = self.folders.pop()
|
||||
@@ -226,6 +240,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
for folder_config in config.folders:
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
folder_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
@@ -257,4 +272,5 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
)
|
||||
|
||||
@@ -44,6 +44,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""A dict containing which traffic types are to be included in the observation."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Flag to include the number of accesses."""
|
||||
file_system_requires_scan: bool = True
|
||||
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of ports."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
@@ -187,6 +189,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
host_config.monitored_traffic = config.monitored_traffic
|
||||
if host_config.include_num_access is None:
|
||||
host_config.include_num_access = config.include_num_access
|
||||
if host_config.file_system_requires_scan is None:
|
||||
host_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
|
||||
@@ -26,6 +26,7 @@ def test_file_observation(simulation):
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
)
|
||||
|
||||
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
|
||||
@@ -53,6 +54,7 @@ def test_folder_observation(simulation):
|
||||
root_folder_obs = FolderObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
num_files=1,
|
||||
files=[],
|
||||
)
|
||||
|
||||
@@ -38,6 +38,7 @@ def test_host_observation(simulation):
|
||||
applications=[],
|
||||
folders=[],
|
||||
network_interfaces=[],
|
||||
file_system_requires_scan=True,
|
||||
)
|
||||
|
||||
assert host_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
|
||||
@@ -17,6 +17,7 @@ def test_file_observation():
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
)
|
||||
assert dog_file_obs.observe(state) == {"health_status": 1}
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
132
tests/unit_tests/_primaite/_game/_agent/test_observations.py
Normal file
132
tests/unit_tests/_primaite/_game/_agent/test_observations.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
|
||||
|
||||
class TestFileSystemRequiresScan:
|
||||
@pytest.mark.parametrize(
|
||||
("yaml_option_string", "expected_val"),
|
||||
(
|
||||
("file_system_requires_scan: true", True),
|
||||
("file_system_requires_scan: false", False),
|
||||
(" ", True),
|
||||
),
|
||||
)
|
||||
def test_obs_config(self, yaml_option_string, expected_val):
|
||||
"""Check that the default behaviour is to set FileSystemRequiresScan to True."""
|
||||
obs_cfg_yaml = f"""
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
{yaml_option_string}
|
||||
include_nmne: true
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {{}}
|
||||
|
||||
"""
|
||||
|
||||
cfg = yaml.safe_load(obs_cfg_yaml)
|
||||
manager = ObservationManager.from_config(cfg)
|
||||
|
||||
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
|
||||
for i, host in enumerate(hosts):
|
||||
folders: List[FolderObservation] = host.folders
|
||||
for j, folder in enumerate(folders):
|
||||
assert folder.file_system_requires_scan == expected_val # Make sure folders require scan by default
|
||||
files: List[FileObservation] = folder.files
|
||||
for k, file in enumerate(files):
|
||||
assert file.file_system_requires_scan == expected_val
|
||||
|
||||
def test_file_require_scan(self):
|
||||
file_state = {"health_status": 3, "visible_status": 1}
|
||||
|
||||
obs_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=True)
|
||||
assert obs_requiring_scan.observe(file_state)["health_status"] == 1
|
||||
|
||||
obs_not_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=False)
|
||||
assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3
|
||||
|
||||
def test_folder_require_scan(self):
|
||||
folder_state = {"health_status": 3, "visible_status": 1}
|
||||
|
||||
obs_requiring_scan = FolderObservation(
|
||||
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
|
||||
)
|
||||
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
|
||||
|
||||
obs_not_requiring_scan = FolderObservation(
|
||||
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False
|
||||
)
|
||||
assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3
|
||||
Reference in New Issue
Block a user