New observations
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.host import NodeObservation
|
||||
from primaite.game.agent.observations.observations import (
|
||||
AbstractObservation,
|
||||
AclObservation,
|
||||
@@ -136,53 +136,3 @@ class UC2BlueObservation(AbstractObservation):
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
"""Container for all observations used by the red agent in UC2."""
|
||||
|
||||
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation):
|
||||
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,21 +1,473 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
# from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
# from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
# from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
WhereType = Iterable[str | int] | None
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
service_name: str
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": service_state["operating_state"],
|
||||
"health_status": service_state["health_state_visible"],
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
return cls(where=parent_where+["services", config.service_name])
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
application_name: str
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
# raise NotImplementedError("TODO NUM EXECUTIONS NEEDS TO BE CONVERTED TO A CATEGORICAL")
|
||||
application_state = access_from_nested_dict(state, self.where)
|
||||
if application_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"num_executions": application_state["num_executions"],
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({
|
||||
"operating_status": spaces.Discrete(7),
|
||||
"health_status": spaces.Discrete(5),
|
||||
"num_executions": spaces.Discrete(4)
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ApplicationObservation:
|
||||
return cls(where=parent_where+["applications", config.application_name])
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
file_name: str
|
||||
include_num_access : bool = False
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool)->None:
|
||||
self.where: WhereType = where
|
||||
self.include_num_access :bool = include_num_access
|
||||
|
||||
self.default_observation: ObsType = {"health_status": 0}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {"health_status": file_state["visible_status"]}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = file_state["num_access"]
|
||||
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
space = {"health_status": spaces.Discrete(6)}
|
||||
if self.include_num_access:
|
||||
space["num_access"] = spaces.Discrete(4)
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation:
|
||||
return cls(where=parent_where+["files", config.file_name], include_num_access=config.include_num_access)
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
folder_name: str
|
||||
files: List[FileObservation.ConfigSchema] = []
|
||||
num_files : int = 0
|
||||
include_num_access : bool = False
|
||||
|
||||
def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None:
|
||||
self.where: WhereType = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(FileObservation(where=None,include_num_access=include_num_access))
|
||||
while len(self.files) > num_files:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation:
|
||||
where = parent_where + ["folders", config.folder_name]
|
||||
|
||||
#pass down shared/common config items
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where = where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
|
||||
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
nic_num: int
|
||||
include_nmne: bool = False
|
||||
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool)->None:
|
||||
self.where = where
|
||||
self.include_nmne : bool = include_nmne
|
||||
|
||||
self.default_observation: ObsType = {"nic_status": 0}
|
||||
if self.include_nmne:
|
||||
self.default_observation.update({"NMNE":{"inbound":0, "outbound":0}})
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
# raise NotImplementedError("TODO: CATEGORISATION")
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs
|
||||
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if self.include_nmne:
|
||||
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne)
|
||||
|
||||
|
||||
class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
services: List[ServiceObservation.ConfigSchema] = []
|
||||
applications: List[ApplicationObservation.ConfigSchema] = []
|
||||
folders: List[FolderObservation.ConfigSchema] = []
|
||||
network_interfaces: List[NICObservation.ConfigSchema] = []
|
||||
num_services: int
|
||||
num_applications: int
|
||||
num_folders: int
|
||||
num_files: int
|
||||
num_nics: int
|
||||
include_nmne: bool
|
||||
include_num_access: bool
|
||||
|
||||
def __init__(self,
|
||||
where: WhereType,
|
||||
services:List[ServiceObservation],
|
||||
applications:List[ApplicationObservation],
|
||||
folders:List[FolderObservation],
|
||||
network_interfaces:List[NICObservation],
|
||||
num_services: int,
|
||||
num_applications: int,
|
||||
num_folders: int,
|
||||
num_files: int,
|
||||
num_nics: int,
|
||||
include_nmne: bool,
|
||||
include_num_access: bool
|
||||
)->None:
|
||||
|
||||
self.where : WhereType = where
|
||||
|
||||
# ensure service list has length equal to num_services by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
self.services.append(ServiceObservation(where=None))
|
||||
while len(self.services) > num_services:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
# ensure application list has length equal to num_applications by truncating or padding
|
||||
self.applications: List[ApplicationObservation] = applications
|
||||
while len(self.applications) < num_applications:
|
||||
self.applications.append(ApplicationObservation(where=None))
|
||||
while len(self.applications) > num_applications:
|
||||
truncated_application = self.applications.pop()
|
||||
msg = f"Too many applications in Node observation space for node. Truncating application {truncated_application.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
# ensure folder list has length equal to num_folders by truncating or padding
|
||||
self.folders: List[FolderObservation] = folders
|
||||
while len(self.folders) < num_folders:
|
||||
self.folders.append(FolderObservation(where = None, files= [], num_files=num_files, include_num_access=include_num_access))
|
||||
while len(self.folders) > num_folders:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
# ensure network_interface list has length equal to num_network_interfaces by truncating or padding
|
||||
self.network_interfaces: List[NICObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics:
|
||||
self.network_interfaces.append(NICObservation(where = None, include_nmne=include_nmne))
|
||||
while len(self.network_interfaces) > num_nics:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_folder.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation: ObsType = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
"num_file_creations": 0,
|
||||
"num_file_deletions": 0,
|
||||
}
|
||||
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
}
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
"num_file_creations" : spaces.Discrete(4),
|
||||
"num_file_deletions" : spaces.Discrete(4),
|
||||
}
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation:
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
#pass down shared/common config items
|
||||
for folder_config in config.folders:
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
services = [ServiceObservation.from_config(config=c,parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
|
||||
|
||||
return cls(
|
||||
where = where,
|
||||
services = services,
|
||||
applications = applications,
|
||||
folders = folders,
|
||||
network_interfaces = nics,
|
||||
num_services = config.num_services,
|
||||
num_applications = config.num_applications,
|
||||
num_folders = config.num_folders,
|
||||
num_files = config.num_files,
|
||||
num_nics = config.num_nics,
|
||||
include_nmne = config.include_nmne,
|
||||
include_num_access = config.include_num_access,
|
||||
)
|
||||
|
||||
|
||||
class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
pass
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
|
||||
class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
pass
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
|
||||
class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
ports: List[PortObservation.ConfigSchema]
|
||||
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
|
||||
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
ports: List[PortObservation.ConfigSchema] = []
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
|
||||
class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Config"""
|
||||
hosts: List[HostObservation.ConfigSchema] = []
|
||||
routers: List[RouterObservation.ConfigSchema] = []
|
||||
firewalls: List[FirewallObservation.ConfigSchema] = []
|
||||
num_services: int = 1
|
||||
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
|
||||
############################ OLD
|
||||
|
||||
class NodeObservation(AbstractObservation, identifier= "OLD"):
|
||||
"""Observation of a node in the network. Includes services, folders and NICs."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -2,11 +2,6 @@ from typing import Dict, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.agent_observations import (
|
||||
UC2BlueObservation,
|
||||
UC2GreenObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -63,11 +58,10 @@ class ObservationManager:
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
|
||||
for obs_cfg in config:
|
||||
obs_type = obs_cfg['type']
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options']))
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -17,6 +18,28 @@ if TYPE_CHECKING:
|
||||
class AbstractObservation(ABC):
|
||||
"""Abstract class for an observation space component."""
|
||||
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
"""Registry of observation components, with their name as key.
|
||||
|
||||
Automatically populated when subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
|
||||
:param identifier: Identifier used to uniquely specify observation component types.
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to create a component with a name that is already in use.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate observation component type {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
@@ -36,274 +59,271 @@ class AbstractObservation(ABC):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component.
|
||||
"""
|
||||
pass
|
||||
def from_config(cls, cfg: Dict) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
ObservationType = cls._registry[cfg['type']]
|
||||
return ObservationType.from_config(cfg=cfg)
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
# class LinkObservation(AbstractObservation):
|
||||
# """Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
# "Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
# def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
# """Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
# :param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
# Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
# zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
# A typical location for a service looks like this:
|
||||
# `['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
# :type where: Optional[List[str]]
|
||||
# """
|
||||
# super().__init__()
|
||||
# self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
# def observe(self, state: Dict) -> Dict:
|
||||
# """Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
# :param state: Simulation state dictionary
|
||||
# :type state: Dict
|
||||
# :return: Observation
|
||||
# :rtype: Dict
|
||||
# """
|
||||
# if self.where is None:
|
||||
# return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
# link_state = access_from_nested_dict(state, self.where)
|
||||
# if link_state is NOT_PRESENT_IN_STATE:
|
||||
# return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
# bandwidth = link_state["bandwidth"]
|
||||
# load = link_state["current_load"]
|
||||
# if load == 0:
|
||||
# utilisation_category = 0
|
||||
# else:
|
||||
# utilisation_fraction = load / bandwidth
|
||||
# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
# utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
# # TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
# @property
|
||||
# def space(self) -> spaces.Space:
|
||||
# """Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
# :return: Gymnasium space
|
||||
# :rtype: spaces.Space
|
||||
# """
|
||||
# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
# @classmethod
|
||||
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
# """Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
# :param config: Dictionary containing the configuration for this link observation.
|
||||
# :type config: Dict
|
||||
# :param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
# :type game: PrimaiteGame
|
||||
# :return: Constructed link observation
|
||||
# :rtype: LinkObservation
|
||||
# """
|
||||
# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
# class AclObservation(AbstractObservation):
|
||||
# """Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
# # TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# # definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# # if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# # this needs adding, but not for the MVP.
|
||||
# def __init__(
|
||||
# self,
|
||||
# node_ip_to_id: Dict[str, int],
|
||||
# ports: List[int],
|
||||
# protocols: List[str],
|
||||
# where: Optional[Tuple[str]] = None,
|
||||
# num_rules: int = 10,
|
||||
# ) -> None:
|
||||
# """Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
# :param node_ip_to_id: Mapping between IP address and ID.
|
||||
# :type node_ip_to_id: Dict[str, int]
|
||||
# :param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
# :type ports: List[int]
|
||||
# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
# :type protocols: list[str]
|
||||
# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
# example may look like this:
|
||||
# ['network','nodes',<router_hostname>,'acl','acl']
|
||||
# :type where: Optional[Tuple[str]], optional
|
||||
# :param num_rules: , defaults to 10
|
||||
# :type num_rules: int, optional
|
||||
# """
|
||||
# super().__init__()
|
||||
# self.where: Optional[Tuple[str]] = where
|
||||
# self.num_rules: int = num_rules
|
||||
# self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
# "List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
# "List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
# "List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
# self.default_observation: Dict = {
|
||||
# i
|
||||
# + 1: {
|
||||
# "position": i,
|
||||
# "permission": 0,
|
||||
# "source_node_id": 0,
|
||||
# "source_port": 0,
|
||||
# "dest_node_id": 0,
|
||||
# "dest_port": 0,
|
||||
# "protocol": 0,
|
||||
# }
|
||||
# for i in range(self.num_rules)
|
||||
# }
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
# def observe(self, state: Dict) -> Dict:
|
||||
# """Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
# :param state: Simulation state dictionary
|
||||
# :type state: Dict
|
||||
# :return: Observation
|
||||
# :rtype: Dict
|
||||
# """
|
||||
# if self.where is None:
|
||||
# return self.default_observation
|
||||
# acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
# if acl_state is NOT_PRESENT_IN_STATE:
|
||||
# return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
# # TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
# obs = {}
|
||||
# acl_items = dict(acl_state.items())
|
||||
# i = 1 # don't show rule 0 for compatibility reasons.
|
||||
# while i < self.num_rules + 1:
|
||||
# rule_state = acl_items[i]
|
||||
# if rule_state is None:
|
||||
# obs[i] = {
|
||||
# "position": i - 1,
|
||||
# "permission": 0,
|
||||
# "source_node_id": 0,
|
||||
# "source_port": 0,
|
||||
# "dest_node_id": 0,
|
||||
# "dest_port": 0,
|
||||
# "protocol": 0,
|
||||
# }
|
||||
# else:
|
||||
# src_ip = rule_state["src_ip_address"]
|
||||
# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
# dst_ip = rule_state["dst_ip_address"]
|
||||
# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
# src_port = rule_state["src_port"]
|
||||
# src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
# dst_port = rule_state["dst_port"]
|
||||
# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
# protocol = rule_state["protocol"]
|
||||
# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
# obs[i] = {
|
||||
# "position": i - 1,
|
||||
# "permission": rule_state["action"],
|
||||
# "source_node_id": src_node_id,
|
||||
# "source_port": src_port_id,
|
||||
# "dest_node_id": dst_node_ip,
|
||||
# "dest_port": dst_port_id,
|
||||
# "protocol": protocol_id,
|
||||
# }
|
||||
# i += 1
|
||||
# return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
# @property
|
||||
# def space(self) -> spaces.Space:
|
||||
# """Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
# :return: Gymnasium space
|
||||
# :rtype: spaces.Space
|
||||
# """
|
||||
# return spaces.Dict(
|
||||
# {
|
||||
# i
|
||||
# + 1: spaces.Dict(
|
||||
# {
|
||||
# "position": spaces.Discrete(self.num_rules),
|
||||
# "permission": spaces.Discrete(3),
|
||||
# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
# "source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
# "dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
# }
|
||||
# )
|
||||
# for i in range(self.num_rules)
|
||||
# }
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
# @classmethod
|
||||
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
# """Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
# :param config: Dictionary containing the configuration for this ACL observation.
|
||||
# :type config: Dict
|
||||
# :param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
# :type game: PrimaiteGame
|
||||
# :return: Observation object
|
||||
# :rtype: AclObservation
|
||||
# """
|
||||
# max_acl_rules = config["options"]["max_acl_rules"]
|
||||
# node_ip_to_idx = {}
|
||||
# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
# node_ref = ip_map_config["node_hostname"]
|
||||
# nic_num = ip_map_config["nic_num"]
|
||||
# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
# nic_obj = node_obj.network_interface[nic_num]
|
||||
# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
# router_hostname = config["router_hostname"]
|
||||
# return cls(
|
||||
# node_ip_to_id=node_ip_to_idx,
|
||||
# ports=game.options.ports,
|
||||
# protocols=game.options.protocols,
|
||||
# where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
# num_rules=max_acl_rules,
|
||||
# )
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
# class NullObservation(AbstractObservation):
|
||||
# """Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
# def __init__(self, where: Optional[List[str]] = None):
|
||||
# """Initialise null observation."""
|
||||
# self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
# def observe(self, state: Dict) -> Dict:
|
||||
# """Generate observation based on the current state of the simulation."""
|
||||
# return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
# @property
|
||||
# def space(self) -> spaces.Space:
|
||||
# """Gymnasium space object describing the observation space shape."""
|
||||
# return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
# @classmethod
|
||||
# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
# """
|
||||
# Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
return cls()
|
||||
# The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
# """
|
||||
# return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
# class ICSObservation(NullObservation):
|
||||
# """ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
# pass
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Hashable, Sequence
|
||||
from typing import Any, Dict, Hashable, Optional, Sequence
|
||||
|
||||
NOT_PRESENT_IN_STATE = object()
|
||||
"""
|
||||
@@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
|
||||
"""
|
||||
|
||||
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any:
|
||||
"""
|
||||
Access an item from a deeply dictionary with a list of keys.
|
||||
|
||||
@@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
if keys is None:
|
||||
return NOT_PRESENT_IN_STATE
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
if len(key_list) == 0:
|
||||
return dictionary
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.host import NodeObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user