#2417 Remove references to old obs names and add link obs

This commit is contained in:
Marek Wolan
2024-03-31 21:39:24 +01:00
parent 15cb2e6970
commit 62ebca8c08
21 changed files with 247 additions and 331 deletions

View File

@@ -119,7 +119,7 @@ SessionManager.
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
- Integration of NMNE capturing functionality within the `NicObservation` class.
- Integration of NMNE capturing functionality within the `NICObservation` class.
- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario
### Removed

View File

@@ -73,7 +73,7 @@ Network Interface Classes
- Malicious Network Events Monitoring:
* Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NICObservation`` to classify and record network anomalies.
* Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms.
**WiredNetworkInterface (Connection Type Layer)**

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -10,6 +10,8 @@ from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -165,7 +167,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation:
"""
Create an ACL observation from a configuration schema.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -9,6 +9,8 @@ from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -73,7 +75,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
return spaces.Dict(space)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation:
"""
Create a file observation from a configuration schema.
@@ -172,7 +174,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
"""
Create a folder observation from a configuration schema.
@@ -190,5 +192,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
for file_config in config.files:
file_config.include_num_access = config.include_num_access
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -10,6 +10,8 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -190,7 +192,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> FirewallObservation:
"""
Create a firewall observation from a configuration schema.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -12,6 +12,8 @@ from primaite.game.agent.observations.observations import AbstractObservation, W
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -184,7 +186,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation:
"""
Create a host observation from a configuration schema.
@@ -196,7 +198,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:return: Constructed host observation instance.
:rtype: HostObservation
"""
if parent_where is None:
if parent_where == []:
where = ["network", "nodes", config.hostname]
else:
where = parent_where + ["nodes", config.hostname]
@@ -208,10 +210,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services]
applications = [
ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications
]
folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces]
return cls(
where=where,

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
from typing import Any, Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class LinkObservation(AbstractObservation, identifier="LINK"):
"""Link observation, providing information about a specific link within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinkObservation."""
link_reference: str
"""Reference identifier for the link."""
def __init__(self, where: WhereType) -> None:
"""
Initialise a link observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this link.
A typical location for a link might be ['network', 'links', <link_reference>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
def observe(self, state: Dict) -> Any:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about the link.
:rtype: Any
"""
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
utilisation_category = int(utilisation_fraction * 9) + 1
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for link status.
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation:
"""
Create a link observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the link observation.
:type config: ConfigSchema
:param game: The PrimaiteGame instance.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this link.
A typical location might be ['network', 'links', <link_reference>].
:type parent_where: WhereType, optional
:return: Constructed link observation instance.
:rtype: LinkObservation
"""
link_reference = game.ref_map_links[config.link_reference]
if parent_where == []:
where = ["network", "links", link_reference]
else:
where = parent_where + ["links", link_reference]
return cls(where=where)
class LinksObservation(AbstractObservation, identifier="LINKS"):
"""Collection of link observations representing multiple links within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinksObservation."""
link_references: List[str]
"""List of reference identifiers for the links."""
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
"""
Initialise a links observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for these links.
A typical location for links might be ['network', 'links'].
:type where: WhereType
:param links: List of link observations.
:type links: List[LinkObservation]
"""
self.where: WhereType = where
self.links: List[LinkObservation] = links
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about multiple links.
:rtype: ObsType
"""
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for multiple links.
:rtype: spaces.Space
"""
return {i + 1: l.space for i, l in enumerate(self.links)}
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
"""
Create a links observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the links observation.
:type config: ConfigSchema
:param game: The PrimaiteGame instance.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about these links.
A typical location might be ['network'].
:type parent_where: WhereType, optional
:return: Constructed links observation instance.
:rtype: LinksObservation
"""
where = parent_where + ["network"]
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs]
return cls(where=where, links=links)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
"""Status information about a network interface within the simulation environment."""
@@ -82,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation:
"""
Create a network interface observation from a configuration schema.
@@ -142,7 +145,7 @@ class PortObservation(AbstractObservation, identifier="PORT"):
return spaces.Dict({"operating_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation:
"""
Create a port observation from a configuration schema.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List
from typing import Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.router_observation import RouterObservation
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -119,7 +121,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation:
"""
Create a nodes observation from a configuration schema.
@@ -178,8 +180,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls]
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)

View File

@@ -1,12 +1,12 @@
from __future__ import annotations
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
@@ -43,7 +43,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NestedObservation."""
components: List[NestedObservation.NestedObservationItem]
components: List[NestedObservation.NestedObservationItem] = []
"""List of observation components to be part of this space."""
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
@@ -54,7 +54,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
"""Default observation is just the default observations of constituents."""
def observe(self, state: Dict) -> Any:
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -76,7 +76,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
@classmethod
def from_config(cls, config: ConfigSchema) -> NestedObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation:
"""
Read the Nested observation config and create all defined subcomponents.
@@ -115,7 +115,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options))
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game)
instances[component.label] = obs_instance
return cls(components=instances)
@@ -170,6 +170,6 @@ class ObservationManager:
"""
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"]))
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
obs_manager = cls(observation)
return obs_manager

View File

@@ -1,6 +1,6 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Type
from typing import Any, Dict, Iterable, Type, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -8,8 +8,9 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
WhereType = Iterable[str | int] | None
@@ -64,272 +65,8 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: ConfigSchema) -> "AbstractObservation":
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
'''
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
pass
'''

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -107,7 +109,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
"""
Create a router observation from a configuration schema.
@@ -137,6 +139,6 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
if config.ports is None:
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict
from typing import Dict, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
"""Service observation, shows status of a service in the simulation environment."""
@@ -57,7 +60,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"):
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> ServiceObservation:
"""
Create a service observation from a configuration schema.
@@ -128,7 +133,9 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> ApplicationObservation:
"""
Create an application observation from a configuration schema.

View File

@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False

View File

@@ -10,8 +10,7 @@ from _pytest.monkeypatch import MonkeyPatch
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations.observations import ICSObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
@@ -525,7 +524,7 @@ def game_and_agent():
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
act_map={},
)
observation_space = ObservationManager(ICSObservation())
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
test_agent = ControlledAgent(

View File

@@ -1,6 +1,6 @@
import pytest
from primaite.game.agent.observations.observations import AclObservation
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
@@ -34,7 +34,7 @@ def test_acl_observations(simulation):
# add router acl rule
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
acl_obs = AclObservation(
acl_obs = ACLObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
node_ip_to_id={},
ports=["NTP", "HTTP", "POSTGRES_SERVER"],

View File

@@ -1,7 +1,7 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.observations import LinkObservation
from primaite.game.agent.observations.link_observation import LinkObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -5,7 +5,7 @@ import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -40,7 +40,7 @@ def test_nic(simulation):
nic: NIC = pc.network_interface[1]
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
@@ -61,13 +61,13 @@ def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
nic_obs = NicObservation(
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
@@ -80,7 +80,7 @@ def test_nic_categories(simulation):
with pytest.raises(Exception):
# should throw an error
NicObservation(
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
@@ -89,7 +89,7 @@ def test_nic_categories(simulation):
with pytest.raises(Exception):
# should throw an error
NicObservation(
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation
@@ -23,7 +23,7 @@ def test_node_observation(simulation):
"""Test a Node observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
node_obs = HostObservation(where=["network", "nodes", pc.hostname])
assert node_obs.space["operating_status"] == spaces.Discrete(5)

View File

@@ -1,4 +1,4 @@
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.sim_container import Simulation
@@ -141,9 +141,9 @@ def test_describe_state_nmne(uc2_network):
def test_capture_nmne_observations(uc2_network):
"""
Tests the NicObservation class's functionality within a simulated network environment.
Tests the NICObservation class's functionality within a simulated network environment.
This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the
This test ensures the observation space, as defined by instances of NICObservation, accurately reflects the
number of MNEs detected based on network activities over multiple iterations.
The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update
@@ -168,8 +168,8 @@ def test_capture_nmne_observations(uc2_network):
set_nmne_config(nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1])
web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1])
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1])
web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1])
# Iterate through a set of test cases to simulate multiple DELETE queries
for i in range(0, 20):

View File

@@ -1,6 +1,5 @@
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations.observations import ICSObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
@@ -52,7 +51,7 @@ def test_probabilistic_agent():
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
},
)
observation_space = ObservationManager(ICSObservation())
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
pa = ProbabilisticAgent(