From 526dcc7ffea26600e330d68e8681f1ed18c5020d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 3 Apr 2024 22:16:54 +0100 Subject: [PATCH] #2450 remove the need to pass Game to observation objects --- .../agent/observations/acl_observation.py | 6 ++---- .../observations/file_system_observations.py | 10 ++++------ .../observations/firewall_observation.py | 8 ++------ .../agent/observations/host_observations.py | 16 ++++++---------- .../agent/observations/link_observation.py | 14 ++++---------- .../agent/observations/nic_observations.py | 9 +++------ .../agent/observations/node_observations.py | 12 +++++------- .../agent/observations/observation_manager.py | 19 ++++++------------- .../game/agent/observations/observations.py | 8 ++------ .../agent/observations/router_observation.py | 10 ++++------ .../observations/software_observation.py | 13 +++---------- src/primaite/game/game.py | 6 +++--- 12 files changed, 44 insertions(+), 87 deletions(-) diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index fc603a8a..934d688e 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -1,7 +1,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -10,8 +10,6 @@ from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -167,7 +165,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): ) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: """ Create an ACL observation from a configuration schema. diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 3e262055..baf27660 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Iterable, List, Optional, TYPE_CHECKING +from typing import Dict, Iterable, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -9,8 +9,6 @@ from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -94,7 +92,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): return spaces.Dict(space) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: """ Create a file observation from a configuration schema. @@ -193,7 +191,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: """ Create a folder observation from a configuration schema. @@ -211,5 +209,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): for file_config in config.files: file_config.include_num_access = config.include_num_access - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files] + files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 0a1498b1..97a8f814 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -10,8 +10,6 @@ from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -200,9 +198,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): return space @classmethod - def from_config( - cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] - ) -> FirewallObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: """ Create a firewall observation from a configuration schema. diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 6dbde789..b15ede9a 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -12,8 +12,6 @@ from primaite.game.agent.observations.observations import AbstractObservation, W from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -201,7 +199,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> HostObservation: """ Create a host observation from a configuration schema. @@ -225,12 +223,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services] - applications = [ - ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications - ] - folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders] - nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces] + services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] + applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] + folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] + nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] return cls( where=where, diff --git a/src/primaite/game/agent/observations/link_observation.py b/src/primaite/game/agent/observations/link_observation.py index b55aae46..03a19fa0 100644 --- a/src/primaite/game/agent/observations/link_observation.py +++ b/src/primaite/game/agent/observations/link_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, Dict, List from gymnasium import spaces from gymnasium.core import ObsType @@ -9,8 +9,6 @@ from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -68,14 +66,12 @@ class LinkObservation(AbstractObservation, identifier="LINK"): return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinkObservation: """ Create a link observation from a configuration schema. :param config: Configuration schema containing the necessary information for the link observation. :type config: ConfigSchema - :param game: The PrimaiteGame instance. - :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this link. A typical location might be ['network', 'links', ]. :type parent_where: WhereType, optional @@ -135,14 +131,12 @@ class LinksObservation(AbstractObservation, identifier="LINKS"): return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)}) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinksObservation: """ Create a links observation from a configuration schema. :param config: Configuration schema containing the necessary information for the links observation. :type config: ConfigSchema - :param game: The PrimaiteGame instance. - :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about these links. A typical location might be ['network']. :type parent_where: WhereType, optional @@ -151,5 +145,5 @@ class LinksObservation(AbstractObservation, identifier="LINKS"): """ where = parent_where + ["network"] link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references] - links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs] + links = [LinkObservation.from_config(c, parent_where=where) for c in link_cfgs] return cls(where=where, links=links) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 44cc7f8f..afce9095 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Optional, TYPE_CHECKING +from typing import Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -8,9 +8,6 @@ from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" @@ -119,7 +116,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return space @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: """ Create a network interface observation from a configuration schema. @@ -179,7 +176,7 @@ class PortObservation(AbstractObservation, identifier="PORT"): return spaces.Dict({"operating_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation: """ Create a port observation from a configuration schema. diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index f11ffebf..8f7ac0fc 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -12,8 +12,6 @@ from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -152,7 +150,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): return space @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation: """ Create a nodes observation from a configuration schema. @@ -211,8 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules - hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts] - routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers] - firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls] + hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] + routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] + firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls] return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls) diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 1d428fa8..047acce6 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -8,9 +8,6 @@ from pydantic import BaseModel, ConfigDict, model_validator, ValidationError from primaite.game.agent.observations.observations import AbstractObservation, WhereType -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - class NestedObservation(AbstractObservation, identifier="CUSTOM"): """Observation type that allows combining other observations into a gymnasium.spaces.Dict space.""" @@ -76,7 +73,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): return spaces.Dict({label: obs.space for label, obs in self.components.items()}) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NestedObservation: """ Read the Nested observation config and create all defined subcomponents. @@ -115,7 +112,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game) + obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options)) instances[component.label] = obs_instance return cls(components=instances) @@ -137,9 +134,7 @@ class NullObservation(AbstractObservation, identifier="NONE"): return spaces.Discrete(1) @classmethod - def from_config( - cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] - ) -> NullObservation: + def from_config(cls, config: NullObservation.ConfigSchema, parent_where: WhereType = []) -> NullObservation: """Instantiate a NullObservation. Accepts parameters to comply with API.""" return cls() @@ -180,7 +175,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager": + def from_config(cls, config: Optional[Dict]) -> "ObservationManager": """ Create observation space from a config. @@ -191,14 +186,12 @@ class ObservationManager: AbstractObservation options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame """ if config is None: return cls(NullObservation()) print(config) obs_type = config["type"] obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game) + observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"])) obs_manager = cls(observation) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 6c9db571..0d6ff2a3 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,6 +1,6 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Type, TYPE_CHECKING +from typing import Any, Dict, Iterable, Type from gymnasium import spaces from gymnasium.core import ObsType @@ -8,8 +8,6 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) WhereType = Iterable[str | int] | None @@ -65,8 +63,6 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config( - cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] - ) -> "AbstractObservation": + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index aeac2766..3f7e6494 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -11,8 +11,6 @@ from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -112,7 +110,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: """ Create a router observation from a configuration schema. @@ -142,6 +140,6 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): if config.ports is None: config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] - ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports] - acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where) + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] + acl = ACLObservation.from_config(config=config.acl, parent_where=where) return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 2c4806d9..f943f540 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, TYPE_CHECKING +from typing import Dict from gymnasium import spaces from gymnasium.core import ObsType @@ -8,9 +8,6 @@ from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - class ServiceObservation(AbstractObservation, identifier="SERVICE"): """Service observation, shows status of a service in the simulation environment.""" @@ -60,9 +57,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod - def from_config( - cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] - ) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: """ Create a service observation from a configuration schema. @@ -153,9 +148,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): ) @classmethod - def from_config( - cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] - ) -> ApplicationObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: """ Create an application observation from a configuration schema. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 034d11bc..2d007193 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -396,8 +396,8 @@ class PrimaiteGame: # 2. create links between nodes for link_cfg in links_cfg: - node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]] - node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]] + node_a = net.get_node_by_hostname(link_cfg["endpoint_a_ref"]) + node_b = net.get_node_by_hostname(link_cfg["endpoint_b_ref"]) if isinstance(node_a, Switch): endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]] else: @@ -419,7 +419,7 @@ class PrimaiteGame: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(observation_space_cfg, game) + obs_space = ObservationManager.from_config(observation_space_cfg) # CREATE ACTION SPACE action_space = ActionManager.from_config(game, action_space_cfg)