#2869 - Updates to agents to make sure they can be generated from a given config. Updates to test suite to reflect code changes

This commit is contained in:
Charlie Crane
2024-12-16 15:57:00 +00:00
parent d9a1a0e26f
commit a4fbd29bb4
8 changed files with 83 additions and 63 deletions

View File

@@ -3,11 +3,11 @@ from __future__ import annotations
import random
from abc import abstractmethod
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
from primaite.game.agent.scripted_agents.interface import AbstractAgent, AbstractScriptedAgent
class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
@@ -20,7 +20,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Abstract TAP agents."""
starting_node_name: str
starting_node_name: Optional[str] = None
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:

View File

@@ -1,12 +1,12 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, Tuple
from typing import Any, Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent
class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agent"):
class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingAgent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
config: "DataManipulationAgent.ConfigSchema"
@@ -14,12 +14,12 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
class ConfigSchema(AbstractTAPAgent.ConfigSchema):
"""Configuration Schema for DataManipulationAgent."""
starting_application_name: Optional[str] = None
starting_application_name: str
def __init__(self) -> None:
"""Initialise DataManipulationAgent."""
self.setup_agent()
# def __init__(self, **kwargs: Any) -> None:
# """Initialise DataManipulationAgent."""
# # self.setup_agent()
# super().__init_subclass__(**kwargs)
@property
def next_execution_timestep(self) -> int:
@@ -41,11 +41,15 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if self.starting_node_name or self.config is None:
self.setup_agent()
self.get_action(obs=obs, timestep=timestep)
if timestep < self.next_execution_timestep:
self.logger.debug(msg="Performing do nothing action")
return "do_nothing", {}
self._set_next_execution_timestep(timestep + self.config._agent_settings.start_settings.frequency)
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
self.logger.info(msg="Performing a data manipulation attack!")
return "node_application_execute", {
"node_name": self.config.starting_node_name,
@@ -55,4 +59,4 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.config._agent_settings.start_settings.start_step)
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)

View File

@@ -115,14 +115,14 @@ class AbstractAgent(BaseModel):
:type agent_settings: Optional[AgentSettings]
"""
agent_name: str = "Abstract_Agent"
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
agent_name: Optional[str] = "Abstract_Agent"
history: List[AgentHistoryItem] = []
_logger: AgentLog = AgentLog(agent_name=agent_name)
_action_manager: Optional[ActionManager] = None
_observation_manager: Optional[ObservationManager] = None
_reward_function: Optional[RewardFunction] = None
_agent_settings: Optional[AgentSettings] = None
action_manager: ActionManager
observation_manager: ObservationManager
reward_function: RewardFunction
agent_settings: Optional[AgentSettings] = None
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
if identifier in cls._registry:
@@ -138,17 +138,17 @@ class AbstractAgent(BaseModel):
@property
def observation_manager(self) -> ObservationManager:
"""Returns the agents observation manager."""
return self.config._observation_manager
return self.config.observation_manager
@property
def action_manager(self) -> ActionManager:
"""Returns the agents action manager."""
return self.config._action_manager
return self.config.action_manager
@property
def reward_function(self) -> RewardFunction:
"""Returns the agents reward function."""
return self.config._reward_function
return self.config.reward_function
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
@@ -232,7 +232,7 @@ class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent")
return super().get_action(obs=obs, timestep=timestep)
class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"):
class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
"""Agent that sends observations to an RL model and receives actions from that model."""
config: "ProxyAgent.ConfigSchema"

View File

@@ -1,28 +1,23 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""Agents with predefined behaviours."""
from typing import Dict, Tuple
from typing import Any, Dict, Tuple
import numpy as np
import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
config: "ProbabilisticAgent.ConfigSchema"
agent_name: str = "ProbabilisticAgent"
rng: Any = np.random.default_rng(np.random.randint(0, 65535))
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""
action_space: ActionManager
class AgentSettings(AgentSettings):
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
@@ -42,16 +37,17 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
)
return v
# def __init__(self, **kwargs) -> None:
# rng_seed = np.random.randint(0, 65535)
# self.rng = np.random.default_rng(rng_seed)
# self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
# super().__init_subclass__(**kwargs)
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""
agent_name: str = "ProbabilisticAgent"
agent_settings: "ProbabilisticAgent.AgentSettings"
@property
def probabilities(self) -> Dict[str, int]:
"""Convenience method to view the probabilities of the Agent."""
return np.asarray(list(self.config.action_probabilities.values()))
return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -13,7 +13,7 @@ class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Random Agents."""
agent_name = "Random_Agent"
agent_name: str = "Random_Agent"
def get_action(self) -> Tuple[str, Dict]:
"""Sample the action space randomly.
@@ -36,7 +36,7 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Periodic Agent."""
agent_name = "Periodic_Agent"
agent_name: str = "Periodic_Agent"
"""Name of the agent."""
start_step: int = 20
"The timestep at which an agent begins performing it's actions."

View File

@@ -174,7 +174,7 @@ class PrimaiteGame:
obs = agent.observation_manager.current_observation
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
if SIM_OUTPUT.save_agent_logs:
agent.config.logger.debug(f"Chosen Action: {action_choice}")
agent.logger.debug(f"Chosen Action: {action_choice}")
request = agent.format_request(action_choice, parameters)
response = self.simulation.apply_request(request)
agent.process_action_response(
@@ -544,14 +544,20 @@ class PrimaiteGame:
# CREATE AGENT
agent_config = agent_cfg.get("agent_settings", {})
agent_config.update(
{"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function}
)
# new_agent_cfg.update{}
print(AbstractAgent._registry)
agent_settings = agent_cfg["agent_settings"]
agent_config = {
"agent_name": agent_ref,
"action_manager": action_space,
"observation_manager": obs_space,
"reward_function": reward_function,
"agent_settings": agent_settings,
}
# new_agent_cfg.update{}
if agent_type in AbstractAgent._registry:
print(agent_type)
print(agent_config)
print(AbstractAgent._registry)
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents
if agent_type == "ProxyAgent":

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple
import pytest
import yaml
@@ -10,6 +10,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.interface import AbstractAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network
@@ -268,12 +269,12 @@ class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
"""Agent that can be controlled by the tests."""
config: "ControlledAgent.ConfigSchema"
most_recent_action: Optional[Tuple[str, Dict]] = None
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Abstract Agent used in tests."""
agent_name: str = "Controlled_Agent"
most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
@@ -496,12 +497,15 @@ def game_and_agent():
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
test_agent = ControlledAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
config = {
"agent_name":"test_agent",
"action_manager":action_space,
"observation_manager":observation_space,
"reward_function":reward_function,
}
test_agent = ControlledAgent.from_config(config=config)
game.agents["test_agent"] = test_agent

View File

@@ -55,15 +55,25 @@ def test_probabilistic_agent():
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
pa = ProbabilisticAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
settings={
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
},
)
# pa = ProbabilisticAgent(
# agent_name="test_agent",
# action_space=action_space,
# observation_space=observation_space,
# reward_function=reward_function,
# settings={
# "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
# },
# )
pa_config = {"agent_name":"test_agent",
"action_manager": action_space,
"observation_manager": observation_space,
"reward_function": reward_function,
"agent_settings": {
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
}}
pa = ProbabilisticAgent.from_config(config=pa_config)
do_nothing_count = 0
node_application_execute_count = 0