#2417 Remove references to old obs names and add link obs
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -10,6 +10,8 @@ from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -165,7 +167,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation:
|
||||
"""
|
||||
Create an ACL observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -9,6 +9,8 @@ from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +75,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation:
|
||||
"""
|
||||
Create a file observation from a configuration schema.
|
||||
|
||||
@@ -172,7 +174,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
|
||||
"""
|
||||
Create a folder observation from a configuration schema.
|
||||
|
||||
@@ -190,5 +192,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -10,6 +10,8 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -190,7 +192,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> FirewallObservation:
|
||||
"""
|
||||
Create a firewall observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -12,6 +12,8 @@ from primaite.game.agent.observations.observations import AbstractObservation, W
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -184,7 +186,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation:
|
||||
"""
|
||||
Create a host observation from a configuration schema.
|
||||
|
||||
@@ -196,7 +198,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
:return: Constructed host observation instance.
|
||||
:rtype: HostObservation
|
||||
"""
|
||||
if parent_where is None:
|
||||
if parent_where == []:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
@@ -208,10 +210,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services]
|
||||
applications = [
|
||||
ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications
|
||||
]
|
||||
folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces]
|
||||
|
||||
return cls(
|
||||
where=where,
|
||||
|
||||
155
src/primaite/game/agent/observations/link_observation.py
Normal file
155
src/primaite/game/agent/observations/link_observation.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation, identifier="LINK"):
|
||||
"""Link observation, providing information about a specific link within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinkObservation."""
|
||||
|
||||
link_reference: str
|
||||
"""Reference identifier for the link."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a link observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this link.
|
||||
A typical location for a link might be ['network', 'links', <link_reference>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about the link.
|
||||
:rtype: Any
|
||||
"""
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for link status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation:
|
||||
"""
|
||||
Create a link observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the link observation.
|
||||
:type config: ConfigSchema
|
||||
:param game: The PrimaiteGame instance.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this link.
|
||||
A typical location might be ['network', 'links', <link_reference>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed link observation instance.
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
link_reference = game.ref_map_links[config.link_reference]
|
||||
if parent_where == []:
|
||||
where = ["network", "links", link_reference]
|
||||
else:
|
||||
where = parent_where + ["links", link_reference]
|
||||
return cls(where=where)
|
||||
|
||||
|
||||
class LinksObservation(AbstractObservation, identifier="LINKS"):
|
||||
"""Collection of link observations representing multiple links within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinksObservation."""
|
||||
|
||||
link_references: List[str]
|
||||
"""List of reference identifiers for the links."""
|
||||
|
||||
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
|
||||
"""
|
||||
Initialise a links observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for these links.
|
||||
A typical location for links might be ['network', 'links'].
|
||||
:type where: WhereType
|
||||
:param links: List of link observations.
|
||||
:type links: List[LinkObservation]
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.links: List[LinkObservation] = links
|
||||
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about multiple links.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for multiple links.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return {i + 1: l.space for i, l in enumerate(self.links)}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
|
||||
"""
|
||||
Create a links observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the links observation.
|
||||
:type config: ConfigSchema
|
||||
:param game: The PrimaiteGame instance.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about these links.
|
||||
A typical location might be ['network'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed links observation instance.
|
||||
:rtype: LinksObservation
|
||||
"""
|
||||
where = parent_where + ["network"]
|
||||
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
|
||||
links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs]
|
||||
return cls(where=where, links=links)
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
@@ -82,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation:
|
||||
"""
|
||||
Create a network interface observation from a configuration schema.
|
||||
|
||||
@@ -142,7 +145,7 @@ class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation:
|
||||
"""
|
||||
Create a port observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -119,7 +121,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation:
|
||||
"""
|
||||
Create a nodes observation from a configuration schema.
|
||||
|
||||
@@ -178,8 +180,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
|
||||
hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls]
|
||||
|
||||
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
@@ -43,7 +43,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NestedObservation."""
|
||||
|
||||
components: List[NestedObservation.NestedObservationItem]
|
||||
components: List[NestedObservation.NestedObservationItem] = []
|
||||
"""List of observation components to be part of this space."""
|
||||
|
||||
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
|
||||
@@ -54,7 +54,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
|
||||
"""Default observation is just the default observations of constituents."""
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
@@ -76,7 +76,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema) -> NestedObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation:
|
||||
"""
|
||||
Read the Nested observation config and create all defined subcomponents.
|
||||
|
||||
@@ -115,7 +115,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options))
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
@@ -170,6 +170,6 @@ class ObservationManager:
|
||||
"""
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"]))
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Type
|
||||
from typing import Any, Dict, Iterable, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -8,8 +8,9 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
WhereType = Iterable[str | int] | None
|
||||
|
||||
|
||||
@@ -64,272 +65,8 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: ConfigSchema) -> "AbstractObservation":
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
|
||||
'''
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
'''
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -107,7 +109,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
|
||||
"""
|
||||
Create a router observation from a configuration schema.
|
||||
|
||||
@@ -137,6 +139,6 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
if config.ports is None:
|
||||
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
"""Service observation, shows status of a service in the simulation environment."""
|
||||
@@ -57,7 +60,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> ServiceObservation:
|
||||
"""
|
||||
Create a service observation from a configuration schema.
|
||||
|
||||
@@ -128,7 +133,9 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> ApplicationObservation:
|
||||
"""
|
||||
Create an application observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
|
||||
Reference in New Issue
Block a user