#2417 Remove references to old obs names and add link obs

This commit is contained in:
Marek Wolan
2024-03-31 21:39:24 +01:00
parent 15cb2e6970
commit 62ebca8c08
21 changed files with 247 additions and 331 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -10,6 +10,8 @@ from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -165,7 +167,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation:
"""
Create an ACL observation from a configuration schema.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -9,6 +9,8 @@ from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -73,7 +75,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
return spaces.Dict(space)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation:
"""
Create a file observation from a configuration schema.
@@ -172,7 +174,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
"""
Create a folder observation from a configuration schema.
@@ -190,5 +192,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
for file_config in config.files:
file_config.include_num_access = config.include_num_access
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -10,6 +10,8 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -190,7 +192,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> FirewallObservation:
"""
Create a firewall observation from a configuration schema.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -12,6 +12,8 @@ from primaite.game.agent.observations.observations import AbstractObservation, W
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -184,7 +186,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation:
"""
Create a host observation from a configuration schema.
@@ -196,7 +198,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:return: Constructed host observation instance.
:rtype: HostObservation
"""
if parent_where is None:
if parent_where == []:
where = ["network", "nodes", config.hostname]
else:
where = parent_where + ["nodes", config.hostname]
@@ -208,10 +210,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services]
applications = [
ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications
]
folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces]
return cls(
where=where,

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
from typing import Any, Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class LinkObservation(AbstractObservation, identifier="LINK"):
"""Link observation, providing information about a specific link within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinkObservation."""
link_reference: str
"""Reference identifier for the link."""
def __init__(self, where: WhereType) -> None:
"""
Initialise a link observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this link.
A typical location for a link might be ['network', 'links', <link_reference>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
def observe(self, state: Dict) -> Any:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about the link.
:rtype: Any
"""
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
utilisation_category = int(utilisation_fraction * 9) + 1
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for link status.
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation:
"""
Create a link observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the link observation.
:type config: ConfigSchema
:param game: The PrimaiteGame instance.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this link.
A typical location might be ['network', 'links', <link_reference>].
:type parent_where: WhereType, optional
:return: Constructed link observation instance.
:rtype: LinkObservation
"""
link_reference = game.ref_map_links[config.link_reference]
if parent_where == []:
where = ["network", "links", link_reference]
else:
where = parent_where + ["links", link_reference]
return cls(where=where)
class LinksObservation(AbstractObservation, identifier="LINKS"):
"""Collection of link observations representing multiple links within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinksObservation."""
link_references: List[str]
"""List of reference identifiers for the links."""
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
"""
Initialise a links observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for these links.
A typical location for links might be ['network', 'links'].
:type where: WhereType
:param links: List of link observations.
:type links: List[LinkObservation]
"""
self.where: WhereType = where
self.links: List[LinkObservation] = links
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about multiple links.
:rtype: ObsType
"""
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for multiple links.
:rtype: spaces.Space
"""
return {i + 1: l.space for i, l in enumerate(self.links)}
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
"""
Create a links observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the links observation.
:type config: ConfigSchema
:param game: The PrimaiteGame instance.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about these links.
A typical location might be ['network'].
:type parent_where: WhereType, optional
:return: Constructed links observation instance.
:rtype: LinksObservation
"""
where = parent_where + ["network"]
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs]
return cls(where=where, links=links)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
"""Status information about a network interface within the simulation environment."""
@@ -82,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation:
"""
Create a network interface observation from a configuration schema.
@@ -142,7 +145,7 @@ class PortObservation(AbstractObservation, identifier="PORT"):
return spaces.Dict({"operating_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation:
"""
Create a port observation from a configuration schema.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List
from typing import Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.router_observation import RouterObservation
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -119,7 +121,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation:
"""
Create a nodes observation from a configuration schema.
@@ -178,8 +180,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls]
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)

View File

@@ -1,12 +1,12 @@
from __future__ import annotations
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
@@ -43,7 +43,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NestedObservation."""
components: List[NestedObservation.NestedObservationItem]
components: List[NestedObservation.NestedObservationItem] = []
"""List of observation components to be part of this space."""
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
@@ -54,7 +54,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
"""Default observation is just the default observations of constituents."""
def observe(self, state: Dict) -> Any:
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -76,7 +76,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
@classmethod
def from_config(cls, config: ConfigSchema) -> NestedObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation:
"""
Read the Nested observation config and create all defined subcomponents.
@@ -115,7 +115,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options))
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game)
instances[component.label] = obs_instance
return cls(components=instances)
@@ -170,6 +170,6 @@ class ObservationManager:
"""
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"]))
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
obs_manager = cls(observation)
return obs_manager

View File

@@ -1,6 +1,6 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Type
from typing import Any, Dict, Iterable, Type, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -8,8 +8,9 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
WhereType = Iterable[str | int] | None
@@ -64,272 +65,8 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: ConfigSchema) -> "AbstractObservation":
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
'''
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
pass
'''

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
@@ -107,7 +109,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
"""
Create a router observation from a configuration schema.
@@ -137,6 +139,6 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
if config.ports is None:
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict
from typing import Dict, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
"""Service observation, shows status of a service in the simulation environment."""
@@ -57,7 +60,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"):
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> ServiceObservation:
"""
Create a service observation from a configuration schema.
@@ -128,7 +133,9 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> ApplicationObservation:
"""
Create an application observation from a configuration schema.

View File

@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False