diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index 6022f697..ee5ed292 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -2,15 +2,13 @@ from __future__ import annotations from abc import ABC -from typing import List - -from pydantic import field_validator +from typing import List, Literal, Union from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat -from primaite.utils.validation.ip_protocol import protocol_validator -from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address -from primaite.utils.validation.port import port_validator +from primaite.utils.validation.ip_protocol import IPProtocol +from primaite.utils.validation.ipv4_address import IPV4Address +from primaite.utils.validation.port import Port __all__ = ( "RouterACLAddRuleAction", @@ -29,43 +27,14 @@ class ACLAddRuleAbstractAction(AbstractAction, ABC): """Configuration Schema base for ACL add rule abstract actions.""" src_ip: IPV4Address - protocol_name: str - permission: str + protocol_name: Union[IPProtocol, Literal["ALL"]] + permission: Literal["ALLOW", "DENY"] position: int - dst_ip: IPV4Address - src_port: int - dst_port: int - src_wildcard: int - dst_wildcard: int - - @field_validator( - "src_port", - "dst_port", - mode="before", - ) - @classmethod - def valid_port(cls, v: str) -> int: - """Check that inputs are valid.""" - return port_validator(v) - - @field_validator( - "src_ip", - "dst_ip", - mode="before", - ) - @classmethod - def valid_ip(cls, v: str) -> str: - """Check that a valid IP has been provided for src and dst.""" - return ipv4_validator(v) - - @field_validator( - "protocol_name", - mode="before", - ) - @classmethod - def is_valid_protocol(cls, v: str) -> bool: - """Check that we are using a valid protocol.""" - return protocol_validator(v) + dst_ip: Union[IPV4Address, Literal["ALL"]] + src_port: Union[Port, Literal["ALL"]] + dst_port: Union[Port, Literal["ALL"]] + src_wildcard: Union[IPV4Address, Literal["NONE"]] + dst_wildcard: Union[IPV4Address, Literal["NONE"]] class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"): @@ -100,10 +69,10 @@ class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_ad "add_rule", config.permission, config.protocol_name, - config.src_ip, + str(config.src_ip), config.src_wildcard, config.src_port, - config.dst_ip, + str(config.dst_ip), config.dst_wildcard, config.dst_port, config.position, @@ -139,7 +108,7 @@ class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_ac firewall_port_direction: str @classmethod - def form_request(cls, config: ConfigSchema) -> List[str]: + def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return [ "network", @@ -151,10 +120,10 @@ class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_ac "add_rule", config.permission, config.protocol_name, - config.src_ip, + str(config.src_ip), config.src_wildcard, config.src_port, - config.dst_ip, + str(config.dst_ip), config.dst_wildcard, config.dst_port, config.position, diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index a6e235c5..fefa22b8 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -84,7 +84,7 @@ class ActionManager(BaseModel): def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" act_class = AbstractAction._registry[action_identifier] - config = act_class.ConfigSchema(**action_options) + config = act_class.ConfigSchema(type=action_identifier, **action_options) return act_class.form_request(config=config) @property diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 3311de66..f5714644 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING +from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict, Field @@ -48,11 +48,20 @@ class AbstractAgent(BaseModel, ABC): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + class AgentSettingsSchema(BaseModel, ABC): + """Schema for the 'agent_settings' key.""" + + model_config = ConfigDict(extra="forbid") + class ConfigSchema(BaseModel, ABC): """Configuration Schema for AbstractAgents.""" model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) type: str + ref: str + """name of the agent.""" + team: Optional[Literal["BLUE", "GREEN", "RED"]] + agent_settings: AbstractAgent.AgentSettingsSchema = Field(default=lambda: AbstractAgent.AgentSettingsSchema()) action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema()) observation_space: ObservationManager.ConfigSchema = Field( default_factory=lambda: ObservationManager.ConfigSchema() @@ -85,11 +94,6 @@ class AbstractAgent(BaseModel, ABC): self.reward_function = RewardFunction(config=self.config.reward_function) return super().model_post_init(__context) - @property - def flatten_obs(self) -> bool: - """Return agent flatten_obs param.""" - return self.config.flatten_obs - def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. @@ -149,6 +153,13 @@ class AbstractAgent(BaseModel, ABC): """Update the most recent history item with the reward value.""" self.history[-1].reward = self.reward_function.current_reward + @classmethod + def from_config(cls, config: Dict) -> AbstractAgent: + """Grab the relevatn agent class and construct an instance from a config dict.""" + agent_type = config["type"] + agent_class = cls._registry[agent_type] + return agent_class(config=config) + class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"): """Base class for actors which generate their own behaviour.""" @@ -172,12 +183,17 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema()) most_recent_action: ActType = None + class AgentSettingsSchema(AbstractAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" + + flatten_obs: bool = False + action_masking: bool = False + class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Proxy Agent.""" type: str = "Proxy_Agent" - flatten_obs: bool = False - action_masking: bool = False + agent_settings: ProxyAgent.AgentSettingsSchema = Field(default_factory=lambda: ProxyAgent.AgentSettingsSchema()) def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ @@ -199,3 +215,8 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): The environment is responsible for calling this method when it receives an action from the agent policy. """ self.most_recent_action = action + + @property + def flatten_obs(self) -> bool: + """Return agent flatten_obs param.""" + return self.config.agent_settings.flatten_obs diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index d4c8ef9b..2881f967 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -377,7 +377,7 @@ class SharedReward(AbstractReward, identifier="SHARED_REWARD"): class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"): - """Apply a negative reward when taking any action except DONOTHING.""" + """Apply a negative reward when taking any action except do_nothing.""" config: "ActionPenalty.ConfigSchema" diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py index 21323578..e6ddd546 100644 --- a/src/primaite/game/agent/scripted_agents/abstract_tap.py +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -3,27 +3,36 @@ from __future__ import annotations import random from abc import abstractmethod -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple from gymnasium.core import ObsType from pydantic import Field -from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent __all__ = "AbstractTAPAgent" -class AbstractTAPAgent(AbstractScriptedAgent, identifier="AbstractTAP"): +class AbstractTAPAgent(PeriodicAgent, identifier="AbstractTAP"): """Base class for TAP agents to inherit from.""" config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema()) next_execution_timestep: int = 0 - class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" + + possible_starting_nodes: List[str] = Field(default_factory=list) + + class ConfigSchema(PeriodicAgent.ConfigSchema): """Configuration schema for Abstract TAP agents.""" type: str = "AbstractTAP" - starting_node_name: Optional[str] = None + agent_settings: AbstractTAPAgent.AgentSettingsSchema = Field( + default_factory=lambda: AbstractTAPAgent.AgentSettingsSchema() + ) + + starting_node: Optional[str] = None @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -40,13 +49,13 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="AbstractTAP"): :param timestep: The timestep to add variance to. """ - random_timestep_increment = random.randint(-self.config.variance, self.config.variance) + random_timestep_increment = random.randint( + -self.config.agent_settings.variance, self.config.agent_settings.variance + ) self.next_execution_timestep = timestep + random_timestep_increment def _select_start_node(self) -> None: """Set the starting starting node of the agent to be a random node from this agent's action manager.""" # we are assuming that every node in the node manager has a data manipulation application at idx 0 - num_nodes = len(self.action_manager.node_names) - starting_node_idx = random.randint(0, num_nodes - 1) - self.config.starting_node_name = self.action_manager.node_names[starting_node_idx] - self.logger.debug(f"Selected starting node: {self.config.starting_node_name}") + self.starting_node = random.choice(self.config.agent_settings.possible_starting_nodes) + self.logger.debug(f"Selected starting node: {self.starting_node}") diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 8fe0690b..a7558d42 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -1,6 +1,5 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -import random -from typing import Dict, List, Tuple +from typing import Dict, Tuple from gymnasium.core import ObsType from pydantic import Field @@ -13,21 +12,24 @@ __all__ = "DataManipulationAgent" class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" + + target_application: str = "DataManipulationBot" + class ConfigSchema(PeriodicAgent.ConfigSchema): """Configuration Schema for DataManipulationAgent.""" type: str = "RedDatabaseCorruptingAgent" - starting_application_name: str = "DataManipulationBot" - possible_start_nodes: List[str] + agent_settings: "DataManipulationAgent.AgentSettingsSchema" = Field( + default_factory=lambda: DataManipulationAgent.AgentSettingsSchema() + ) config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema()) - start_node: str - def __init__(self, **kwargs): - kwargs["start_node"] = random.choice(kwargs["config"].possible_start_nodes) super().__init__(**kwargs) - self._set_next_execution_timestep(timestep=self.config.start_step, variance=0) + self._set_next_execution_timestep(timestep=self.config.agent_settings.start_step, variance=0) def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: """Waits until a specific timestep, then attempts to execute its data manipulation application. @@ -43,9 +45,11 @@ class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgen self.logger.debug(msg="Performing do nothing action") return "do_nothing", {} - self._set_next_execution_timestep(timestep=timestep + self.config.frequency, variance=self.config.variance) + self._set_next_execution_timestep( + timestep=timestep + self.config.agent_settings.frequency, variance=self.config.agent_settings.variance + ) self.logger.info(msg="Performing a data manipulation attack!") return "node_application_execute", { "node_name": self.start_node, - "application_name": self.config.starting_application_name, + "application_name": self.config.agent_settings.target_application, } diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index 8e714f55..20924a95 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -19,10 +19,8 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema()) rng: Generator = np.random.default_rng(np.random.randint(0, 65535)) - class ConfigSchema(AbstractScriptedAgent.ConfigSchema): - """Configuration schema for Probabilistic Agent.""" - - type: str = "ProbabilisticAgent" + class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" action_probabilities: Dict[int, float] = None """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" @@ -46,10 +44,18 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") ) return v + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration schema for Probabilistic Agent.""" + + type: str = "ProbabilisticAgent" + agent_settings: "ProbabilisticAgent.AgentSettingsSchema" = Field( + default_factory=lambda: ProbabilisticAgent.AgentSettingsSchema() + ) + @property def probabilities(self) -> Dict[str, int]: """Convenience method to view the probabilities of the Agent.""" - return np.asarray(list(self.config.action_probabilities.values())) + return np.asarray(list(self.config.agent_settings.action_probabilities.values())) def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index 999669d8..721b5293 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -1,9 +1,10 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK import random -from typing import Dict, Tuple +from functools import cached_property +from typing import Dict, List, Tuple from gymnasium.core import ObsType -from pydantic import Field, model_validator +from pydantic import computed_field, Field, model_validator from primaite.game.agent.interface import AbstractScriptedAgent @@ -38,17 +39,17 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema()) - class ConfigSchema(AbstractScriptedAgent.ConfigSchema): - """Configuration Schema for Periodic Agent.""" + class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema): + """Schema for the `agent_settings` part of the agent config.""" - type: str = "PeriodicAgent" - """Name of the agent.""" start_step: int = 5 "The timestep at which an agent begins performing it's actions" frequency: int = 5 "The number of timesteps to wait between performing actions" variance: int = 0 "The amount the frequency can randomly change to" + possible_start_nodes: List[str] + target_application: str @model_validator(mode="after") def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema": @@ -66,6 +67,15 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): ) return self + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration Schema for Periodic Agent.""" + + type: str = "PeriodicAgent" + """Name of the agent.""" + agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field( + default_factory=lambda: PeriodicAgent.AgentSettingsSchema() + ) + max_executions: int = 999999 "Maximum number of times the agent can execute its action." num_executions: int = 0 @@ -73,6 +83,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): next_execution_timestep: int = 0 """Timestep of the next action execution by the agent.""" + @computed_field + @cached_property + def start_node(self) -> str: + """On instantiation, randomly select a start node.""" + return random.choice(self.config.agent_settings.possible_start_nodes) + def _set_next_execution_timestep(self, timestep: int, variance: int) -> None: """Set the next execution timestep with a configured random variance. @@ -88,8 +104,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): """Do nothing, unless the current timestep is the next execution timestep, in which case do the action.""" if timestep == self.next_execution_timestep and self.num_executions < self.max_executions: self.num_executions += 1 - self._set_next_execution_timestep(timestep + self.config.frequency, self.config.variance) - self.target_node = self.action_manager.node_names[0] - return "node_application_execute", {"node_name": self.target_node, "application_name": 0} + self._set_next_execution_timestep( + timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance + ) + return "node_application_execute", { + "node_name": self.start_node, + "application_name": self.config.agent_settings.target_application, + } return "do_nothing", {} diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 5220e874..650d6a10 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -525,33 +525,37 @@ class PrimaiteGame: agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: - agent_name = agent_cfg["ref"] # noqa: F841 - agent_type = agent_cfg["type"] - action_space_cfg = agent_cfg["action_space"] - observation_space_cfg = agent_cfg["observation_space"] - reward_function_cfg = agent_cfg["reward_function"] - agent_settings = agent_cfg["agent_settings"] - - agent_config = { - "type": agent_type, - "action_space": action_space_cfg, - "observation_space": observation_space_cfg, - "reward_function": reward_function_cfg, - "agent_settings": agent_settings, - "game": game, - } - - # CREATE AGENT - if agent_type in AbstractAgent._registry: - new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config) - # If blue agent is created, add to game.rl_agents - if agent_type == "ProxyAgent": - game.rl_agents[agent_cfg["ref"]] = new_agent - else: - msg = f"Configuration error: {agent_type} is not a valid agent type." - _LOGGER.error(msg) - raise ValueError(msg) + new_agent = AbstractAgent.from_config(agent_cfg) game.agents[agent_cfg["ref"]] = new_agent + if isinstance(new_agent, ProxyAgent): + game.rl_agents[agent_cfg["ref"]] = new_agent + + # agent_name = agent_cfg["ref"] # noqa: F841 + # agent_type = agent_cfg["type"] + # action_space_cfg = agent_cfg["action_space"] + # observation_space_cfg = agent_cfg["observation_space"] + # reward_function_cfg = agent_cfg["reward_function"] + # agent_settings = agent_cfg["agent_settings"] + + # agent_config = { + # "type": agent_type, + # "action_space": action_space_cfg, + # "observation_space": observation_space_cfg, + # "reward_function": reward_function_cfg, + # "agent_settings": agent_settings, + # "game": game, + # } + + # # CREATE AGENT + # if agent_type in AbstractAgent._registry: + # new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config) + # # If blue agent is created, add to game.rl_agents + # if agent_type == "ProxyAgent": + # game.rl_agents[agent_cfg["ref"]] = new_agent + # else: + # msg = f"Configuration error: {agent_type} is not a valid agent type." + # _LOGGER.error(msg) + # raise ValueError(msg) # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. game.setup_reward_sharing() diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 29f7c33d..b7a9a042 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -89,7 +89,7 @@ class PrimaiteGymEnv(gymnasium.Env): :return: Action mask :rtype: List[bool] """ - if not self.agent.config.action_masking: + if not self.agent.config.agent_settings.action_masking: return np.asarray([True] * len(self.agent.action_manager.action_map)) else: return self.game.action_mask(self._agent_name) diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 0c96714e..16c85cb3 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): ) for agent_name in self._agent_ids: agent = self.game.rl_agents[agent_name] - if agent.config.action_masking: + if agent.config.agent_settings.action_masking: self.observation_space[agent_name] = spaces.Dict( { "action_mask": spaces.MultiBinary(agent.action_manager.space.n), @@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation obs = gymnasium.spaces.flatten(unflat_space, unflat_obs) - if agent.config.action_masking: + if agent.config.agent_settings.action_masking: all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs} else: all_obs[agent_name] = obs @@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" super().reset() # Ensure PRNG seed is set everywhere - if self.env.agent.config.action_masking: + if self.env.agent.config.agent_settings.action_masking: obs, *_ = self.env.reset(seed=seed) new_obs = {"action_mask": self.env.action_masks(), "observations": obs} return new_obs, *_ @@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env): def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" # if action masking is enabled, intercept the step method and add action mask to observation - if self.env.agent.config.action_masking: + if self.env.agent.config.agent_settings.action_masking: obs, *_ = self.env.step(action) new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs} return new_obs, *_ diff --git a/tests/conftest.py b/tests/conftest.py index c8c5e694..aa4a0ef0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -414,74 +414,13 @@ def game_and_agent(): sim = game.simulation install_stuff_to_sim(sim) - actions = [ - {"type": "do_nothing"}, - {"type": "node_service_scan"}, - {"type": "node_service_stop"}, - {"type": "node_service_start"}, - {"type": "node_service_pause"}, - {"type": "node_service_resume"}, - {"type": "node_service_restart"}, - {"type": "node_service_disable"}, - {"type": "node_service_enable"}, - {"type": "node_service_fix"}, - {"type": "node_application_execute"}, - {"type": "node_application_scan"}, - {"type": "node_application_close"}, - {"type": "node_application_fix"}, - {"type": "node_application_install"}, - {"type": "node_application_remove"}, - {"type": "node_file_create"}, - {"type": "node_file_scan"}, - {"type": "node_file_checkhash"}, - {"type": "node_file_delete"}, - {"type": "node_file_repair"}, - {"type": "node_file_restore"}, - {"type": "node_file_corrupt"}, - {"type": "node_file_access"}, - {"type": "node_folder_create"}, - {"type": "node_folder_scan"}, - {"type": "node_folder_checkhash"}, - {"type": "node_folder_repair"}, - {"type": "node_folder_restore"}, - {"type": "node_os_scan"}, - {"type": "node_shutdown"}, - {"type": "node_startup"}, - {"type": "node_reset"}, - {"type": "router_acl_add_rule"}, - {"type": "router_acl_remove_rule"}, - {"type": "host_nic_enable"}, - {"type": "host_nic_disable"}, - {"type": "network_port_enable"}, - {"type": "network_port_disable"}, - {"type": "configure_c2_beacon"}, - {"type": "c2_server_ransomware_launch"}, - {"type": "c2_server_ransomware_configure"}, - {"type": "c2_server_terminal_command"}, - {"type": "c2_server_data_exfiltrate"}, - {"type": "node_account_change_password"}, - {"type": "node_session_remote_login"}, - {"type": "node_session_remote_logoff"}, - {"type": "node_send_remote_command"}, - ] - - action_space = ActionManager( - actions=actions, # ALL POSSIBLE ACTIONS - act_map={}, - ) - observation_space = ObservationManager(NestedObservation(components={})) - reward_function = RewardFunction() - config = { "type": "ControlledAgent", - "agent_name": "test_agent", - "action_manager": action_space, - "observation_manager": observation_space, - "reward_function": reward_function, - "agent_settings": {}, + "ref": "test_agent", + "team": "BLUE", } - test_agent = ControlledAgent.from_config(config=config) + test_agent = ControlledAgent(config=config) game.agents["test_agent"] = test_agent