#2417 Remove references to old obs names and add link obs
This commit is contained in:
@@ -119,7 +119,7 @@ SessionManager.
|
||||
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
|
||||
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
|
||||
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
|
||||
- Integration of NMNE capturing functionality within the `NicObservation` class.
|
||||
- Integration of NMNE capturing functionality within the `NICObservation` class.
|
||||
- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario
|
||||
|
||||
### Removed
|
||||
|
||||
@@ -73,7 +73,7 @@ Network Interface Classes
|
||||
- Malicious Network Events Monitoring:
|
||||
|
||||
* Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns.
|
||||
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies.
|
||||
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NICObservation`` to classify and record network anomalies.
|
||||
* Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms.
|
||||
|
||||
**WiredNetworkInterface (Connection Type Layer)**
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -10,6 +10,8 @@ from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -165,7 +167,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation:
|
||||
"""
|
||||
Create an ACL observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -9,6 +9,8 @@ from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +75,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation:
|
||||
"""
|
||||
Create a file observation from a configuration schema.
|
||||
|
||||
@@ -172,7 +174,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
|
||||
"""
|
||||
Create a folder observation from a configuration schema.
|
||||
|
||||
@@ -190,5 +192,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -10,6 +10,8 @@ from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -190,7 +192,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> FirewallObservation:
|
||||
"""
|
||||
Create a firewall observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -12,6 +12,8 @@ from primaite.game.agent.observations.observations import AbstractObservation, W
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -184,7 +186,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation:
|
||||
"""
|
||||
Create a host observation from a configuration schema.
|
||||
|
||||
@@ -196,7 +198,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
:return: Constructed host observation instance.
|
||||
:rtype: HostObservation
|
||||
"""
|
||||
if parent_where is None:
|
||||
if parent_where == []:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
@@ -208,10 +210,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services]
|
||||
applications = [
|
||||
ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications
|
||||
]
|
||||
folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces]
|
||||
|
||||
return cls(
|
||||
where=where,
|
||||
|
||||
155
src/primaite/game/agent/observations/link_observation.py
Normal file
155
src/primaite/game/agent/observations/link_observation.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation, identifier="LINK"):
|
||||
"""Link observation, providing information about a specific link within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinkObservation."""
|
||||
|
||||
link_reference: str
|
||||
"""Reference identifier for the link."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a link observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this link.
|
||||
A typical location for a link might be ['network', 'links', <link_reference>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about the link.
|
||||
:rtype: Any
|
||||
"""
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for link status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation:
|
||||
"""
|
||||
Create a link observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the link observation.
|
||||
:type config: ConfigSchema
|
||||
:param game: The PrimaiteGame instance.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this link.
|
||||
A typical location might be ['network', 'links', <link_reference>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed link observation instance.
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
link_reference = game.ref_map_links[config.link_reference]
|
||||
if parent_where == []:
|
||||
where = ["network", "links", link_reference]
|
||||
else:
|
||||
where = parent_where + ["links", link_reference]
|
||||
return cls(where=where)
|
||||
|
||||
|
||||
class LinksObservation(AbstractObservation, identifier="LINKS"):
|
||||
"""Collection of link observations representing multiple links within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinksObservation."""
|
||||
|
||||
link_references: List[str]
|
||||
"""List of reference identifiers for the links."""
|
||||
|
||||
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
|
||||
"""
|
||||
Initialise a links observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for these links.
|
||||
A typical location for links might be ['network', 'links'].
|
||||
:type where: WhereType
|
||||
:param links: List of link observations.
|
||||
:type links: List[LinkObservation]
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.links: List[LinkObservation] = links
|
||||
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about multiple links.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for multiple links.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return {i + 1: l.space for i, l in enumerate(self.links)}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
|
||||
"""
|
||||
Create a links observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the links observation.
|
||||
:type config: ConfigSchema
|
||||
:param game: The PrimaiteGame instance.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about these links.
|
||||
A typical location might be ['network'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed links observation instance.
|
||||
:rtype: LinksObservation
|
||||
"""
|
||||
where = parent_where + ["network"]
|
||||
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
|
||||
links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs]
|
||||
return cls(where=where, links=links)
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
@@ -82,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation:
|
||||
"""
|
||||
Create a network interface observation from a configuration schema.
|
||||
|
||||
@@ -142,7 +145,7 @@ class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation:
|
||||
"""
|
||||
Create a port observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -119,7 +121,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation:
|
||||
"""
|
||||
Create a nodes observation from a configuration schema.
|
||||
|
||||
@@ -178,8 +180,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
|
||||
hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls]
|
||||
|
||||
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
@@ -43,7 +43,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NestedObservation."""
|
||||
|
||||
components: List[NestedObservation.NestedObservationItem]
|
||||
components: List[NestedObservation.NestedObservationItem] = []
|
||||
"""List of observation components to be part of this space."""
|
||||
|
||||
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
|
||||
@@ -54,7 +54,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
|
||||
"""Default observation is just the default observations of constituents."""
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
@@ -76,7 +76,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema) -> NestedObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation:
|
||||
"""
|
||||
Read the Nested observation config and create all defined subcomponents.
|
||||
|
||||
@@ -115,7 +115,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options))
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
@@ -170,6 +170,6 @@ class ObservationManager:
|
||||
"""
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"]))
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Type
|
||||
from typing import Any, Dict, Iterable, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -8,8 +8,9 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
WhereType = Iterable[str | int] | None
|
||||
|
||||
|
||||
@@ -64,272 +65,8 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: ConfigSchema) -> "AbstractObservation":
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
|
||||
'''
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
'''
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -11,6 +11,8 @@ from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -107,7 +109,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
|
||||
"""
|
||||
Create a router observation from a configuration schema.
|
||||
|
||||
@@ -137,6 +139,6 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
if config.ports is None:
|
||||
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -8,6 +8,9 @@ from gymnasium.core import ObsType
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
"""Service observation, shows status of a service in the simulation environment."""
|
||||
@@ -57,7 +60,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> ServiceObservation:
|
||||
"""
|
||||
Create a service observation from a configuration schema.
|
||||
|
||||
@@ -128,7 +133,9 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> ApplicationObservation:
|
||||
"""
|
||||
Create an application observation from a configuration schema.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
|
||||
@@ -10,8 +10,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations.observations import ICSObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import PrimaiteSession
|
||||
@@ -525,7 +524,7 @@ def game_and_agent():
|
||||
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
|
||||
act_map={},
|
||||
)
|
||||
observation_space = ObservationManager(ICSObservation())
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
|
||||
test_agent = ControlledAgent(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.observations.observations import AclObservation
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -34,7 +34,7 @@ def test_acl_observations(simulation):
|
||||
# add router acl rule
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
|
||||
|
||||
acl_obs = AclObservation(
|
||||
acl_obs = ACLObservation(
|
||||
where=["network", "nodes", router.hostname, "acl", "acl"],
|
||||
node_ip_to_id={},
|
||||
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import LinkObservation
|
||||
from primaite.game.agent.observations.link_observation import LinkObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
import yaml
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -40,7 +40,7 @@ def test_nic(simulation):
|
||||
|
||||
nic: NIC = pc.network_interface[1]
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
|
||||
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
|
||||
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
|
||||
@@ -61,13 +61,13 @@ def test_nic_categories(simulation):
|
||||
"""Test the NIC observation nmne count categories."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 10 # default
|
||||
assert nic_obs.med_nmne_threshold == 5 # default
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
nic_obs = NicObservation(
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
@@ -80,7 +80,7 @@ def test_nic_categories(simulation):
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
@@ -89,7 +89,7 @@ def test_nic_categories(simulation):
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
@@ -23,7 +23,7 @@ def test_node_observation(simulation):
|
||||
"""Test a Node observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
|
||||
node_obs = HostObservation(where=["network", "nodes", pc.hostname])
|
||||
|
||||
assert node_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
@@ -141,9 +141,9 @@ def test_describe_state_nmne(uc2_network):
|
||||
|
||||
def test_capture_nmne_observations(uc2_network):
|
||||
"""
|
||||
Tests the NicObservation class's functionality within a simulated network environment.
|
||||
Tests the NICObservation class's functionality within a simulated network environment.
|
||||
|
||||
This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the
|
||||
This test ensures the observation space, as defined by instances of NICObservation, accurately reflects the
|
||||
number of MNEs detected based on network activities over multiple iterations.
|
||||
|
||||
The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update
|
||||
@@ -168,8 +168,8 @@ def test_capture_nmne_observations(uc2_network):
|
||||
set_nmne_config(nmne_config)
|
||||
|
||||
# Define observations for the NICs of the database and web servers
|
||||
db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1])
|
||||
web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1])
|
||||
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1])
|
||||
web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1])
|
||||
|
||||
# Iterate through a set of test cases to simulate multiple DELETE queries
|
||||
for i in range(0, 20):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations.observations import ICSObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
|
||||
@@ -52,7 +51,7 @@ def test_probabilistic_agent():
|
||||
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
|
||||
},
|
||||
)
|
||||
observation_space = ObservationManager(ICSObservation())
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
|
||||
pa = ProbabilisticAgent(
|
||||
|
||||
Reference in New Issue
Block a user