#2869 - Refactor agent and action config system
This commit is contained in:
@@ -2,15 +2,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
|
||||
from pydantic import field_validator
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import protocol_validator
|
||||
from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address
|
||||
from primaite.utils.validation.port import port_validator
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
__all__ = (
|
||||
"RouterACLAddRuleAction",
|
||||
@@ -29,43 +27,14 @@ class ACLAddRuleAbstractAction(AbstractAction, ABC):
|
||||
"""Configuration Schema base for ACL add rule abstract actions."""
|
||||
|
||||
src_ip: IPV4Address
|
||||
protocol_name: str
|
||||
permission: str
|
||||
protocol_name: Union[IPProtocol, Literal["ALL"]]
|
||||
permission: Literal["ALLOW", "DENY"]
|
||||
position: int
|
||||
dst_ip: IPV4Address
|
||||
src_port: int
|
||||
dst_port: int
|
||||
src_wildcard: int
|
||||
dst_wildcard: int
|
||||
|
||||
@field_validator(
|
||||
"src_port",
|
||||
"dst_port",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def valid_port(cls, v: str) -> int:
|
||||
"""Check that inputs are valid."""
|
||||
return port_validator(v)
|
||||
|
||||
@field_validator(
|
||||
"src_ip",
|
||||
"dst_ip",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def valid_ip(cls, v: str) -> str:
|
||||
"""Check that a valid IP has been provided for src and dst."""
|
||||
return ipv4_validator(v)
|
||||
|
||||
@field_validator(
|
||||
"protocol_name",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def is_valid_protocol(cls, v: str) -> bool:
|
||||
"""Check that we are using a valid protocol."""
|
||||
return protocol_validator(v)
|
||||
dst_ip: Union[IPV4Address, Literal["ALL"]]
|
||||
src_port: Union[Port, Literal["ALL"]]
|
||||
dst_port: Union[Port, Literal["ALL"]]
|
||||
src_wildcard: Union[IPV4Address, Literal["NONE"]]
|
||||
dst_wildcard: Union[IPV4Address, Literal["NONE"]]
|
||||
|
||||
|
||||
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
|
||||
@@ -100,10 +69,10 @@ class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_ad
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
config.src_ip,
|
||||
str(config.src_ip),
|
||||
config.src_wildcard,
|
||||
config.src_port,
|
||||
config.dst_ip,
|
||||
str(config.dst_ip),
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
@@ -139,7 +108,7 @@ class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_ac
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
@@ -151,10 +120,10 @@ class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_ac
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
config.src_ip,
|
||||
str(config.src_ip),
|
||||
config.src_wildcard,
|
||||
config.src_port,
|
||||
config.dst_ip,
|
||||
str(config.dst_ip),
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
|
||||
@@ -84,7 +84,7 @@ class ActionManager(BaseModel):
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_class = AbstractAction._registry[action_identifier]
|
||||
config = act_class.ConfigSchema(**action_options)
|
||||
config = act_class.ConfigSchema(type=action_identifier, **action_options)
|
||||
return act_class.form_request(config=config)
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -48,11 +48,20 @@ class AbstractAgent(BaseModel, ABC):
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
class AgentSettingsSchema(BaseModel, ABC):
|
||||
"""Schema for the 'agent_settings' key."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configuration Schema for AbstractAgents."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
type: str
|
||||
ref: str
|
||||
"""name of the agent."""
|
||||
team: Optional[Literal["BLUE", "GREEN", "RED"]]
|
||||
agent_settings: AbstractAgent.AgentSettingsSchema = Field(default=lambda: AbstractAgent.AgentSettingsSchema())
|
||||
action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
observation_space: ObservationManager.ConfigSchema = Field(
|
||||
default_factory=lambda: ObservationManager.ConfigSchema()
|
||||
@@ -85,11 +94,6 @@ class AbstractAgent(BaseModel, ABC):
|
||||
self.reward_function = RewardFunction(config=self.config.reward_function)
|
||||
return super().model_post_init(__context)
|
||||
|
||||
@property
|
||||
def flatten_obs(self) -> bool:
|
||||
"""Return agent flatten_obs param."""
|
||||
return self.config.flatten_obs
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
@@ -149,6 +153,13 @@ class AbstractAgent(BaseModel, ABC):
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractAgent:
|
||||
"""Grab the relevatn agent class and construct an instance from a config dict."""
|
||||
agent_type = config["type"]
|
||||
agent_class = cls._registry[agent_type]
|
||||
return agent_class(config=config)
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
@@ -172,12 +183,17 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema())
|
||||
most_recent_action: ActType = None
|
||||
|
||||
class AgentSettingsSchema(AbstractAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
flatten_obs: bool = False
|
||||
action_masking: bool = False
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Proxy Agent."""
|
||||
|
||||
type: str = "Proxy_Agent"
|
||||
flatten_obs: bool = False
|
||||
action_masking: bool = False
|
||||
agent_settings: ProxyAgent.AgentSettingsSchema = Field(default_factory=lambda: ProxyAgent.AgentSettingsSchema())
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
@@ -199,3 +215,8 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
@property
|
||||
def flatten_obs(self) -> bool:
|
||||
"""Return agent flatten_obs param."""
|
||||
return self.config.agent_settings.flatten_obs
|
||||
|
||||
@@ -377,7 +377,7 @@ class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
"""Apply a negative reward when taking any action except do_nothing."""
|
||||
|
||||
config: "ActionPenalty.ConfigSchema"
|
||||
|
||||
|
||||
@@ -3,27 +3,36 @@ from __future__ import annotations
|
||||
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
|
||||
__all__ = "AbstractTAPAgent"
|
||||
|
||||
|
||||
class AbstractTAPAgent(AbstractScriptedAgent, identifier="AbstractTAP"):
|
||||
class AbstractTAPAgent(PeriodicAgent, identifier="AbstractTAP"):
|
||||
"""Base class for TAP agents to inherit from."""
|
||||
|
||||
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
|
||||
next_execution_timestep: int = 0
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
possible_starting_nodes: List[str] = Field(default_factory=list)
|
||||
|
||||
class ConfigSchema(PeriodicAgent.ConfigSchema):
|
||||
"""Configuration schema for Abstract TAP agents."""
|
||||
|
||||
type: str = "AbstractTAP"
|
||||
starting_node_name: Optional[str] = None
|
||||
agent_settings: AbstractTAPAgent.AgentSettingsSchema = Field(
|
||||
default_factory=lambda: AbstractTAPAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
starting_node: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -40,13 +49,13 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="AbstractTAP"):
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(-self.config.variance, self.config.variance)
|
||||
random_timestep_increment = random.randint(
|
||||
-self.config.agent_settings.variance, self.config.agent_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.config.starting_node_name = self.action_manager.node_names[starting_node_idx]
|
||||
self.logger.debug(f"Selected starting node: {self.config.starting_node_name}")
|
||||
self.starting_node = random.choice(self.config.agent_settings.possible_starting_nodes)
|
||||
self.logger.debug(f"Selected starting node: {self.starting_node}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
@@ -13,21 +12,24 @@ __all__ = "DataManipulationAgent"
|
||||
class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgent"):
|
||||
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
|
||||
|
||||
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
target_application: str = "DataManipulationBot"
|
||||
|
||||
class ConfigSchema(PeriodicAgent.ConfigSchema):
|
||||
"""Configuration Schema for DataManipulationAgent."""
|
||||
|
||||
type: str = "RedDatabaseCorruptingAgent"
|
||||
starting_application_name: str = "DataManipulationBot"
|
||||
possible_start_nodes: List[str]
|
||||
agent_settings: "DataManipulationAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: DataManipulationAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema())
|
||||
|
||||
start_node: str
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["start_node"] = random.choice(kwargs["config"].possible_start_nodes)
|
||||
super().__init__(**kwargs)
|
||||
self._set_next_execution_timestep(timestep=self.config.start_step, variance=0)
|
||||
self._set_next_execution_timestep(timestep=self.config.agent_settings.start_step, variance=0)
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
|
||||
@@ -43,9 +45,11 @@ class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgen
|
||||
self.logger.debug(msg="Performing do nothing action")
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep=timestep + self.config.frequency, variance=self.config.variance)
|
||||
self._set_next_execution_timestep(
|
||||
timestep=timestep + self.config.agent_settings.frequency, variance=self.config.agent_settings.variance
|
||||
)
|
||||
self.logger.info(msg="Performing a data manipulation attack!")
|
||||
return "node_application_execute", {
|
||||
"node_name": self.start_node,
|
||||
"application_name": self.config.starting_application_name,
|
||||
"application_name": self.config.agent_settings.target_application,
|
||||
}
|
||||
|
||||
@@ -19,10 +19,8 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
|
||||
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
|
||||
rng: Generator = np.random.default_rng(np.random.randint(0, 65535))
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Probabilistic Agent."""
|
||||
|
||||
type: str = "ProbabilisticAgent"
|
||||
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
action_probabilities: Dict[int, float] = None
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
@@ -46,10 +44,18 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
|
||||
)
|
||||
return v
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Probabilistic Agent."""
|
||||
|
||||
type: str = "ProbabilisticAgent"
|
||||
agent_settings: "ProbabilisticAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: ProbabilisticAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
@property
|
||||
def probabilities(self) -> Dict[str, int]:
|
||||
"""Convenience method to view the probabilities of the Agent."""
|
||||
return np.asarray(list(self.config.action_probabilities.values()))
|
||||
return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import computed_field, Field, model_validator
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
@@ -38,17 +39,17 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
|
||||
config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema())
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Periodic Agent."""
|
||||
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
type: str = "PeriodicAgent"
|
||||
"""Name of the agent."""
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing it's actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to"
|
||||
possible_start_nodes: List[str]
|
||||
target_application: str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
|
||||
@@ -66,6 +67,15 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
)
|
||||
return self
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Periodic Agent."""
|
||||
|
||||
type: str = "PeriodicAgent"
|
||||
"""Name of the agent."""
|
||||
agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: PeriodicAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
num_executions: int = 0
|
||||
@@ -73,6 +83,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
next_execution_timestep: int = 0
|
||||
"""Timestep of the next action execution by the agent."""
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def start_node(self) -> str:
|
||||
"""On instantiation, randomly select a start node."""
|
||||
return random.choice(self.config.agent_settings.possible_start_nodes)
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
@@ -88,8 +104,12 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
|
||||
self.num_executions += 1
|
||||
self._set_next_execution_timestep(timestep + self.config.frequency, self.config.variance)
|
||||
self.target_node = self.action_manager.node_names[0]
|
||||
return "node_application_execute", {"node_name": self.target_node, "application_name": 0}
|
||||
self._set_next_execution_timestep(
|
||||
timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance
|
||||
)
|
||||
return "node_application_execute", {
|
||||
"node_name": self.start_node,
|
||||
"application_name": self.config.agent_settings.target_application,
|
||||
}
|
||||
|
||||
return "do_nothing", {}
|
||||
|
||||
@@ -525,33 +525,37 @@ class PrimaiteGame:
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_name = agent_cfg["ref"] # noqa: F841
|
||||
agent_type = agent_cfg["type"]
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
agent_settings = agent_cfg["agent_settings"]
|
||||
|
||||
agent_config = {
|
||||
"type": agent_type,
|
||||
"action_space": action_space_cfg,
|
||||
"observation_space": observation_space_cfg,
|
||||
"reward_function": reward_function_cfg,
|
||||
"agent_settings": agent_settings,
|
||||
"game": game,
|
||||
}
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type in AbstractAgent._registry:
|
||||
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
|
||||
# If blue agent is created, add to game.rl_agents
|
||||
if agent_type == "ProxyAgent":
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
else:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
new_agent = AbstractAgent.from_config(agent_cfg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
if isinstance(new_agent, ProxyAgent):
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# agent_name = agent_cfg["ref"] # noqa: F841
|
||||
# agent_type = agent_cfg["type"]
|
||||
# action_space_cfg = agent_cfg["action_space"]
|
||||
# observation_space_cfg = agent_cfg["observation_space"]
|
||||
# reward_function_cfg = agent_cfg["reward_function"]
|
||||
# agent_settings = agent_cfg["agent_settings"]
|
||||
|
||||
# agent_config = {
|
||||
# "type": agent_type,
|
||||
# "action_space": action_space_cfg,
|
||||
# "observation_space": observation_space_cfg,
|
||||
# "reward_function": reward_function_cfg,
|
||||
# "agent_settings": agent_settings,
|
||||
# "game": game,
|
||||
# }
|
||||
|
||||
# # CREATE AGENT
|
||||
# if agent_type in AbstractAgent._registry:
|
||||
# new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
|
||||
# # If blue agent is created, add to game.rl_agents
|
||||
# if agent_type == "ProxyAgent":
|
||||
# game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
# else:
|
||||
# msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
# _LOGGER.error(msg)
|
||||
# raise ValueError(msg)
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
@@ -89,7 +89,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
if not self.agent.config.action_masking:
|
||||
if not self.agent.config.agent_settings.action_masking:
|
||||
return np.asarray([True] * len(self.agent.action_manager.action_map))
|
||||
else:
|
||||
return self.game.action_mask(self._agent_name)
|
||||
|
||||
@@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
)
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
if agent.config.action_masking:
|
||||
if agent.config.agent_settings.action_masking:
|
||||
self.observation_space[agent_name] = spaces.Dict(
|
||||
{
|
||||
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
|
||||
@@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if agent.config.action_masking:
|
||||
if agent.config.agent_settings.action_masking:
|
||||
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
|
||||
else:
|
||||
all_obs[agent_name] = obs
|
||||
@@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.config.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.config.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
|
||||
return new_obs, *_
|
||||
|
||||
@@ -414,74 +414,13 @@ def game_and_agent():
|
||||
sim = game.simulation
|
||||
install_stuff_to_sim(sim)
|
||||
|
||||
actions = [
|
||||
{"type": "do_nothing"},
|
||||
{"type": "node_service_scan"},
|
||||
{"type": "node_service_stop"},
|
||||
{"type": "node_service_start"},
|
||||
{"type": "node_service_pause"},
|
||||
{"type": "node_service_resume"},
|
||||
{"type": "node_service_restart"},
|
||||
{"type": "node_service_disable"},
|
||||
{"type": "node_service_enable"},
|
||||
{"type": "node_service_fix"},
|
||||
{"type": "node_application_execute"},
|
||||
{"type": "node_application_scan"},
|
||||
{"type": "node_application_close"},
|
||||
{"type": "node_application_fix"},
|
||||
{"type": "node_application_install"},
|
||||
{"type": "node_application_remove"},
|
||||
{"type": "node_file_create"},
|
||||
{"type": "node_file_scan"},
|
||||
{"type": "node_file_checkhash"},
|
||||
{"type": "node_file_delete"},
|
||||
{"type": "node_file_repair"},
|
||||
{"type": "node_file_restore"},
|
||||
{"type": "node_file_corrupt"},
|
||||
{"type": "node_file_access"},
|
||||
{"type": "node_folder_create"},
|
||||
{"type": "node_folder_scan"},
|
||||
{"type": "node_folder_checkhash"},
|
||||
{"type": "node_folder_repair"},
|
||||
{"type": "node_folder_restore"},
|
||||
{"type": "node_os_scan"},
|
||||
{"type": "node_shutdown"},
|
||||
{"type": "node_startup"},
|
||||
{"type": "node_reset"},
|
||||
{"type": "router_acl_add_rule"},
|
||||
{"type": "router_acl_remove_rule"},
|
||||
{"type": "host_nic_enable"},
|
||||
{"type": "host_nic_disable"},
|
||||
{"type": "network_port_enable"},
|
||||
{"type": "network_port_disable"},
|
||||
{"type": "configure_c2_beacon"},
|
||||
{"type": "c2_server_ransomware_launch"},
|
||||
{"type": "c2_server_ransomware_configure"},
|
||||
{"type": "c2_server_terminal_command"},
|
||||
{"type": "c2_server_data_exfiltrate"},
|
||||
{"type": "node_account_change_password"},
|
||||
{"type": "node_session_remote_login"},
|
||||
{"type": "node_session_remote_logoff"},
|
||||
{"type": "node_send_remote_command"},
|
||||
]
|
||||
|
||||
action_space = ActionManager(
|
||||
actions=actions, # ALL POSSIBLE ACTIONS
|
||||
act_map={},
|
||||
)
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
|
||||
config = {
|
||||
"type": "ControlledAgent",
|
||||
"agent_name": "test_agent",
|
||||
"action_manager": action_space,
|
||||
"observation_manager": observation_space,
|
||||
"reward_function": reward_function,
|
||||
"agent_settings": {},
|
||||
"ref": "test_agent",
|
||||
"team": "BLUE",
|
||||
}
|
||||
|
||||
test_agent = ControlledAgent.from_config(config=config)
|
||||
test_agent = ControlledAgent(config=config)
|
||||
|
||||
game.agents["test_agent"] = test_agent
|
||||
|
||||
|
||||
Reference in New Issue
Block a user