Merge remote-tracking branch 'origin/dev' into feature/2769-implement-user-account-action-space

This commit is contained in:
Marek Wolan
2024-08-13 13:15:47 +01:00
8 changed files with 213 additions and 8 deletions

View File

@@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
main port they're assigned.
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
### Fixed
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
- Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install`
and `uninstall` methods in the `Node` class.
- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the

View File

@@ -23,8 +23,10 @@ class FileObservation(AbstractObservation, identifier="FILE"):
"""Name of the file, used for querying simulation state dictionary."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to the file in the observation."""
file_system_requires_scan: Optional[bool] = None
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(self, where: WhereType, include_num_access: bool) -> None:
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
"""
Initialise a file observation instance.
@@ -34,9 +36,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:type where: WhereType
:param include_num_access: Whether to include the number of accesses to the file in the observation.
:type include_num_access: bool
:param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
self.include_num_access: bool = include_num_access
self.file_system_requires_scan: bool = file_system_requires_scan
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
@@ -74,7 +80,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {"health_status": file_state["visible_status"]}
if self.file_system_requires_scan:
health_status = file_state["visible_status"]
else:
health_status = file_state["health_status"]
obs = {"health_status": health_status}
if self.include_num_access:
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@@ -104,8 +114,15 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:type parent_where: WhereType, optional
:return: Constructed file observation instance.
:rtype: FileObservation
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
return cls(
where=parent_where + ["files", config.file_name],
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)
class FolderObservation(AbstractObservation, identifier="FOLDER"):
@@ -122,9 +139,16 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
"""Number of spaces for file observations in this folder."""
include_num_access: Optional[bool] = None
"""Whether files in this folder should include the number of accesses in their observation."""
file_system_requires_scan: Optional[bool] = None
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
self,
where: WhereType,
files: Iterable[FileObservation],
num_files: int,
include_num_access: bool,
file_system_requires_scan: bool,
) -> None:
"""
Initialise a folder observation instance.
@@ -138,12 +162,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
:type num_files: int
:param include_num_access: Whether to include the number of accesses to files in the observation.
:type include_num_access: bool
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
self.file_system_requires_scan: bool = file_system_requires_scan
self.files: List[FileObservation] = files
while len(self.files) < num_files:
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
self.files.append(
FileObservation(
where=None,
include_num_access=include_num_access,
file_system_requires_scan=self.file_system_requires_scan,
)
)
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
@@ -168,7 +203,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
if self.file_system_requires_scan:
health_status = folder_state["visible_status"]
else:
health_status = folder_state["health_status"]
obs = {}
@@ -209,6 +247,13 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
# pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
file_config.file_system_requires_scan = config.file_system_requires_scan
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
return cls(
where=where,
files=files,
num_files=config.num_files,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)

View File

@@ -48,6 +48,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""A dict containing which traffic types are to be included in the observation."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to files observations on this host."""
file_system_requires_scan: Optional[bool] = None
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
def __init__(
self,
@@ -64,6 +68,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_nmne: bool,
monitored_traffic: Optional[Dict],
include_num_access: bool,
file_system_requires_scan: bool,
) -> None:
"""
Initialise a host observation instance.
@@ -95,6 +100,9 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:type monitored_traffic: Dict
:param include_num_access: Flag to include the number of accesses to files.
:type include_num_access: bool
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
If False, the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
@@ -120,7 +128,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
FolderObservation(
where=None,
files=[],
num_files=num_files,
include_num_access=include_num_access,
file_system_requires_scan=file_system_requires_scan,
)
)
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
@@ -226,6 +240,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
folder_config.file_system_requires_scan = config.file_system_requires_scan
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
@@ -257,4 +272,5 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)

View File

@@ -44,6 +44,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""A dict containing which traffic types are to be included in the observation."""
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
file_system_requires_scan: bool = True
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
@@ -187,6 +189,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
host_config.monitored_traffic = config.monitored_traffic
if host_config.include_num_access is None:
host_config.include_num_access = config.include_num_access
if host_config.file_system_requires_scan is None:
host_config.file_system_requires_scan = config.file_system_requires_scan
for router_config in config.routers:
if router_config.num_ports is None:

View File

@@ -26,6 +26,7 @@ def test_file_observation(simulation):
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
@@ -53,6 +54,7 @@ def test_folder_observation(simulation):
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
include_num_access=False,
file_system_requires_scan=True,
num_files=1,
files=[],
)

View File

@@ -38,6 +38,7 @@ def test_host_observation(simulation):
applications=[],
folders=[],
network_interfaces=[],
file_system_requires_scan=True,
)
assert host_obs.space["operating_status"] == spaces.Discrete(5)

View File

@@ -17,6 +17,7 @@ def test_file_observation():
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})

View File

@@ -0,0 +1,132 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import List
import pytest
import yaml
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.game.agent.observations.host_observations import HostObservation
class TestFileSystemRequiresScan:
@pytest.mark.parametrize(
("yaml_option_string", "expected_val"),
(
("file_system_requires_scan: true", True),
("file_system_requires_scan: false", False),
(" ", True),
),
)
def test_obs_config(self, yaml_option_string, expected_val):
"""Check that the default behaviour is to set FileSystemRequiresScan to True."""
obs_cfg_yaml = f"""
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
{yaml_option_string}
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {{}}
"""
cfg = yaml.safe_load(obs_cfg_yaml)
manager = ObservationManager.from_config(cfg)
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
for i, host in enumerate(hosts):
folders: List[FolderObservation] = host.folders
for j, folder in enumerate(folders):
assert folder.file_system_requires_scan == expected_val # Make sure folders require scan by default
files: List[FileObservation] = folder.files
for k, file in enumerate(files):
assert file.file_system_requires_scan == expected_val
def test_file_require_scan(self):
file_state = {"health_status": 3, "visible_status": 1}
obs_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=True)
assert obs_requiring_scan.observe(file_state)["health_status"] == 1
obs_not_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=False)
assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3
def test_folder_require_scan(self):
folder_state = {"health_status": 3, "visible_status": 1}
obs_requiring_scan = FolderObservation(
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
)
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
obs_not_requiring_scan = FolderObservation(
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False
)
assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3