#2869 - Updates agents so they can be generated from a given config; updates the test suite to reflect the code changes
@@ -3,11 +3,11 @@ from __future__ import annotations
 
 import random
 from abc import abstractmethod
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 from gymnasium.core import ObsType
 
-from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
+from primaite.game.agent.scripted_agents.interface import AbstractAgent, AbstractScriptedAgent
 
 
 class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
@@ -20,7 +20,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
     class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
         """Configuration schema for Abstract TAP agents."""
 
-        starting_node_name: str
+        starting_node_name: Optional[str] = None
 
     @abstractmethod
    def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
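
The change above makes `starting_node_name` optional so a TAP agent's schema can be validated from a config dict before any node has been assigned. A minimal standalone sketch of the effect (Pydantic v2 assumed; the node name shown is a hypothetical value, not from the repo):

```python
# Minimal sketch, not PrimAITE code: why Optional-with-default matters here.
from typing import Optional
from pydantic import BaseModel

class ConfigSchema(BaseModel):
    starting_node_name: Optional[str] = None  # previously: starting_node_name: str (required)

cfg = ConfigSchema()  # validates now; the required version raised ValidationError here
cfg.starting_node_name = "database_server"  # hypothetical name, later chosen by _select_start_node()
```
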
@@ -1,12 +1,12 @@
 # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
-from typing import Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from gymnasium.core import ObsType
 
 from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent
 
 
-class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agent"):
+class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingAgent"):
     """Agent that uses a DataManipulationBot to perform an SQL injection attack."""
 
     config: "DataManipulationAgent.ConfigSchema"
@@ -14,12 +14,12 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
 
     class ConfigSchema(AbstractTAPAgent.ConfigSchema):
         """Configuration Schema for DataManipulationAgent."""
+        starting_application_name: Optional[str] = None
 
-        starting_application_name: str
-
-    def __init__(self) -> None:
-        """Initialise DataManipulationAgent."""
-        self.setup_agent()
+    # def __init__(self, **kwargs: Any) -> None:
+    #     """Initialise DataManipulationAgent."""
+    #     # self.setup_agent()
+    #     super().__init_subclass__(**kwargs)
 
     @property
     def next_execution_timestep(self) -> int:
@@ -41,11 +41,15 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
         :return: Action formatted in CAOS format
         :rtype: Tuple[str, Dict]
         """
+        if self.starting_node_name or self.config is None:
+            self.setup_agent()
+            self.get_action(obs=obs, timestep=timestep)
+
         if timestep < self.next_execution_timestep:
             self.logger.debug(msg="Performing do nothing action")
             return "do_nothing", {}
 
-        self._set_next_execution_timestep(timestep + self.config._agent_settings.start_settings.frequency)
+        self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
         self.logger.info(msg="Performing a data manipulation attack!")
         return "node_application_execute", {
             "node_name": self.config.starting_node_name,
@@ -55,4 +59,4 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
     def setup_agent(self) -> None:
         """Set the next execution timestep when the episode resets."""
         self._select_start_node()
-        self._set_next_execution_timestep(self.config._agent_settings.start_settings.start_step)
+        self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
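
Both hunks above swap `self.config._agent_settings` for `self.config.agent_settings`. In Pydantic, an underscore-prefixed attribute is a private attribute, not a model field, so it can never be populated from a config dict. A minimal sketch of the distinction (Pydantic v2 assumed; field contents illustrative):

```python
# Sketch only: private attributes vs. public fields in Pydantic v2.
from typing import Optional
from pydantic import BaseModel, PrivateAttr

class WithPrivate(BaseModel):
    _agent_settings: Optional[dict] = PrivateAttr(default=None)

class WithPublic(BaseModel):
    agent_settings: Optional[dict] = None

print(WithPrivate.model_fields)  # {} - private attrs are invisible to validation
print(WithPublic.model_fields)   # {'agent_settings': FieldInfo(...)}
WithPublic(agent_settings={"start_step": 20, "frequency": 5})  # populated straight from config
```
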
@@ -115,14 +115,14 @@ class AbstractAgent(BaseModel):
     :type agent_settings: Optional[AgentSettings]
     """
 
-    agent_name: str = "Abstract_Agent"
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    agent_name: Optional[str] = "Abstract_Agent"
     history: List[AgentHistoryItem] = []
     _logger: AgentLog = AgentLog(agent_name=agent_name)
-    _action_manager: Optional[ActionManager] = None
-    _observation_manager: Optional[ObservationManager] = None
-    _reward_function: Optional[RewardFunction] = None
-    _agent_settings: Optional[AgentSettings] = None
+    action_manager: ActionManager
+    observation_manager: ObservationManager
+    reward_function: RewardFunction
+    agent_settings: Optional[AgentSettings] = None
 
     def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
         if identifier in cls._registry:
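
`arbitrary_types_allowed=True` is what lets plain (non-Pydantic) classes such as `ActionManager` become required model fields; without it, Pydantic v2 refuses to build the model class at definition time. A minimal sketch with a stand-in class:

```python
# Sketch only: ActionManager here is a stand-in, not the PrimAITE class.
from pydantic import BaseModel, ConfigDict

class ActionManager:
    pass

class Agent(BaseModel):
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
    action_manager: ActionManager  # without the flag: PydanticSchemaGenerationError at class creation

agent = Agent(action_manager=ActionManager())  # validated with a plain isinstance check
```
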
@@ -138,17 +138,17 @@ class AbstractAgent(BaseModel):
     @property
     def observation_manager(self) -> ObservationManager:
         """Returns the agents observation manager."""
-        return self.config._observation_manager
+        return self.config.observation_manager
 
     @property
     def action_manager(self) -> ActionManager:
         """Returns the agents action manager."""
-        return self.config._action_manager
+        return self.config.action_manager
 
     @property
     def reward_function(self) -> RewardFunction:
         """Returns the agents reward function."""
-        return self.config._reward_function
+        return self.config.reward_function
 
     @classmethod
     def from_config(cls, config: Dict) -> "AbstractAgent":
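
The hunk above reroutes the manager properties to the new public config fields and ends at the `from_config` signature; the factory's body sits outside this diff. A hedged sketch of the usual shape of such a factory (the body shown is an assumption, not the repository's code):

```python
# Hedged sketch of a ConfigSchema-validating factory; not the actual from_config body.
from typing import Dict, Optional
from pydantic import BaseModel

class AbstractAgent(BaseModel):
    class ConfigSchema(BaseModel):
        agent_name: Optional[str] = "Abstract_Agent"

    config: "AbstractAgent.ConfigSchema"

    @classmethod
    def from_config(cls, config: Dict) -> "AbstractAgent":
        """Validate the raw dict against this class's ConfigSchema, then build the agent."""
        return cls(config=cls.ConfigSchema(**config))

agent = AbstractAgent.from_config(config={"agent_name": "demo_agent"})  # hypothetical usage
```
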
@@ -232,7 +232,7 @@ class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent")
         return super().get_action(obs=obs, timestep=timestep)
 
 
-class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"):
+class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
     """Agent that sends observations to an RL model and receives actions from that model."""
 
     config: "ProxyAgent.ConfigSchema"
@@ -1,28 +1,23 @@
 # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
 """Agents with predefined behaviours."""
-from typing import Dict, Tuple
+from typing import Any, Dict, Tuple
 
 import numpy as np
 import pydantic
 from gymnasium.core import ObsType
 
 from primaite.game.agent.actions import ActionManager
-from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
+from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings
 
 
 class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
     """Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
 
     config: "ProbabilisticAgent.ConfigSchema"
+    agent_name: str = "ProbabilisticAgent"
+    rng: Any = np.random.default_rng(np.random.randint(0, 65535))
 
-    class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
-        """Configuration schema for Probabilistic Agent."""
-
-        action_space: ActionManager
+    class AgentSettings(AgentSettings):
         action_probabilities: Dict[int, float]
         """Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
 
         @pydantic.field_validator("action_probabilities", mode="after")
         @classmethod
         def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
@@ -42,16 +37,17 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
             )
             return v
 
-    # def __init__(self, **kwargs) -> None:
-    #     rng_seed = np.random.randint(0, 65535)
-    #     self.rng = np.random.default_rng(rng_seed)
-    #     self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
-    #     super().__init_subclass__(**kwargs)
+    class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
+        """Configuration schema for Probabilistic Agent."""
+
+        agent_name: str = "ProbabilisticAgent"
+        agent_settings: "ProbabilisticAgent.AgentSettings"
+
 
     @property
     def probabilities(self) -> Dict[str, int]:
         """Convenience method to view the probabilities of the Agent."""
-        return np.asarray(list(self.config.action_probabilities.values()))
+        return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
 
     def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
         """
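
With `action_probabilities` now validated inside `AgentSettings` and an `rng` field seeded at class definition, action selection remains a probability-weighted draw. A standalone sketch of that sampling technique (names and values illustrative, not the repo's code):

```python
# Sketch only: probability-weighted action sampling as ProbabilisticAgent uses it.
import numpy as np

action_probabilities = {0: 0.7, 1: 0.2, 2: 0.1}  # stand-ins for the configured values
rng = np.random.default_rng(np.random.randint(0, 65535))  # same seeding pattern as the rng field

p = np.asarray(list(action_probabilities.values()))
assert np.isclose(p.sum(), 1.0)  # the invariant probabilities_sum_to_one enforces
action = int(rng.choice(list(action_probabilities), p=p))  # draw an action index
```
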
@@ -13,7 +13,7 @@ class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
     class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
         """Configuration Schema for Random Agents."""
 
-        agent_name = "Random_Agent"
+        agent_name: str = "Random_Agent"
 
     def get_action(self) -> Tuple[str, Dict]:
         """Sample the action space randomly.
@@ -36,7 +36,7 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
     class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
         """Configuration Schema for Periodic Agent."""
 
-        agent_name = "Periodic_Agent"
+        agent_name: str = "Periodic_Agent"
         """Name of the agent."""
         start_step: int = 20
         "The timestep at which an agent begins performing it's actions."
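
The two one-line fixes above matter because, in Pydantic v2, a bare class attribute without a type annotation is not treated as a model field; model creation fails with a `PydanticUserError`. A minimal sketch:

```python
# Sketch only: why `agent_name = ...` had to become `agent_name: str = ...`.
from pydantic import BaseModel

class Good(BaseModel):
    agent_name: str = "Random_Agent"  # a real field with a default

# class Bad(BaseModel):
#     agent_name = "Random_Agent"  # PydanticUserError: "A non-annotated attribute was detected"

print(Good().agent_name)  # -> Random_Agent
```
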
@@ -174,7 +174,7 @@ class PrimaiteGame:
             obs = agent.observation_manager.current_observation
             action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
             if SIM_OUTPUT.save_agent_logs:
-                agent.config.logger.debug(f"Chosen Action: {action_choice}")
+                agent.logger.debug(f"Chosen Action: {action_choice}")
             request = agent.format_request(action_choice, parameters)
             response = self.simulation.apply_request(request)
             agent.process_action_response(
@@ -544,14 +544,20 @@ class PrimaiteGame:
 
             # CREATE AGENT
 
-            agent_config = agent_cfg.get("agent_settings", {})
-            agent_config.update(
-                {"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function}
-            )
-            # new_agent_cfg.update{}
-            print(AbstractAgent._registry)
+            agent_settings = agent_cfg["agent_settings"]
+            agent_config = {
+                "agent_name": agent_ref,
+                "action_manager": action_space,
+                "observation_manager": obs_space,
+                "reward_function": reward_function,
+                "agent_settings": agent_settings,
+            }
+
+            # new_agent_cfg.update{}
             if agent_type in AbstractAgent._registry:
+                print(agent_type)
+                print(agent_config)
+                print(AbstractAgent._registry)
                 new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
                 # If blue agent is created, add to game.rl_agents
                 if agent_type == "ProxyAgent":
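
The rebuilt `agent_config` dict above is handed to `from_config` on a class looked up in `AbstractAgent._registry`, which `__init_subclass__` populates. A simplified standalone sketch of that registry pattern (error handling details are assumed):

```python
# Sketch only: identifier-keyed subclass registry as used by AbstractAgent.
from typing import Any, ClassVar, Dict

class AbstractAgent:
    _registry: ClassVar[Dict[str, type]] = {}

    def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if identifier in AbstractAgent._registry:
            raise ValueError(f"Duplicate agent identifier: {identifier}")  # assumed error handling
        AbstractAgent._registry[identifier] = cls

class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
    pass

agent_cls = AbstractAgent._registry["ProxyAgent"]  # the same lookup game.py performs
```
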
@@ -1,5 +1,5 @@
 # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import pytest
 import yaml
@@ -10,6 +10,7 @@ from primaite.game.agent.actions import ActionManager
 from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
 from primaite.game.agent.rewards import RewardFunction
 from primaite.game.agent.scripted_agents.interface import AbstractAgent
+from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
 from primaite.game.game import PrimaiteGame
 from primaite.simulator.file_system.file_system import FileSystem
 from primaite.simulator.network.container import Network
@@ -268,12 +269,12 @@ class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
     """Agent that can be controlled by the tests."""
 
     config: "ControlledAgent.ConfigSchema"
+    most_recent_action: Optional[Tuple[str, Dict]] = None
 
     class ConfigSchema(AbstractAgent.ConfigSchema):
         """Configuration Schema for Abstract Agent used in tests."""
 
         agent_name: str = "Controlled_Agent"
-        most_recent_action: Tuple[str, Dict]
 
     def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
         """Return the agent's most recent action, formatted in CAOS format."""
@@ -496,12 +497,15 @@ def game_and_agent():
     observation_space = ObservationManager(NestedObservation(components={}))
     reward_function = RewardFunction()
 
-    test_agent = ControlledAgent(
-        agent_name="test_agent",
-        action_space=action_space,
-        observation_space=observation_space,
-        reward_function=reward_function,
-    )
+    config = {
+        "agent_name": "test_agent",
+        "action_manager": action_space,
+        "observation_manager": observation_space,
+        "reward_function": reward_function,
+    }
+
+    test_agent = ControlledAgent.from_config(config=config)
 
     game.agents["test_agent"] = test_agent
@@ -55,15 +55,25 @@ def test_probabilistic_agent():
     observation_space = ObservationManager(NestedObservation(components={}))
     reward_function = RewardFunction()
 
-    pa = ProbabilisticAgent(
-        agent_name="test_agent",
-        action_space=action_space,
-        observation_space=observation_space,
-        reward_function=reward_function,
-        settings={
-            "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
-        },
-    )
+    # pa = ProbabilisticAgent(
+    #     agent_name="test_agent",
+    #     action_space=action_space,
+    #     observation_space=observation_space,
+    #     reward_function=reward_function,
+    #     settings={
+    #         "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
+    #     },
+    # )
+
+    pa_config = {
+        "agent_name": "test_agent",
+        "action_manager": action_space,
+        "observation_manager": observation_space,
+        "reward_function": reward_function,
+        "agent_settings": {
+            "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
+        },
+    }
+
+    pa = ProbabilisticAgent.from_config(config=pa_config)
 
     do_nothing_count = 0
     node_application_execute_count = 0
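
The counters at the end of the hunk suggest the test goes on to tally sampled actions. A generic sketch of how empirical frequencies can be checked against configured probabilities (counts and tolerance are illustrative, not the repository's assertions):

```python
# Illustrative only: frequency check for probability-weighted sampling.
import numpy as np

rng = np.random.default_rng(42)
probs = {0: 0.6, 1: 0.3, 2: 0.1}  # stand-ins for P_DO_NOTHING etc.
samples = rng.choice(list(probs), p=list(probs.values()), size=10_000)
for action, p in probs.items():
    empirical = np.mean(samples == action)
    assert abs(empirical - p) < 0.02  # loose tolerance for sampling noise
```
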