#2869 - Updates to address test failures. Updated YAML configs to remove redundant start_settings

This commit is contained in:
Charlie Crane
2025-01-08 14:42:35 +00:00
parent 66d309871f
commit 7af9d3724f
28 changed files with 111 additions and 130 deletions

View File

@@ -151,10 +151,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -150,10 +150,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender_1
team: BLUE

View File

@@ -20,7 +20,6 @@ reds: &reds
- type: DUMMY
agent_settings:
start_settings:
start_step: 10
frequency: 10
variance: 0
start_step: 10
frequency: 10
variance: 0

View File

@@ -20,7 +20,6 @@ reds: &reds
- type: DUMMY
agent_settings:
start_settings:
start_step: 3
frequency: 2
variance: 1
start_step: 3
frequency: 2
variance: 1

View File

@@ -47,13 +47,14 @@ class AbstractAgent(BaseModel):
"""Base class for scripted and RL agents."""
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
_logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
history: List[AgentHistoryItem] = []
config: "AbstractAgent.ConfigSchema"
action_manager: "ActionManager"
observation_manager: "ObservationManager"
reward_function: "RewardFunction"
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
class ConfigSchema(BaseModel):
"""
@@ -114,11 +115,12 @@ class AbstractAgent(BaseModel):
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
print(config)
obj = cls(
config=cls.ConfigSchema(**config["agent_settings"]),
action_manager=ActionManager.from_config(**config["action_manager"]),
observation_manager=ObservationManager.from_config(**config["observation_space"]),
reward_function=RewardFunction.from_config(**config["reward_function"]),
action_manager=ActionManager.from_config(config["game"], config["action_manager"]),
observation_manager=ObservationManager.from_config(config["observation_manager"]),
reward_function=RewardFunction.from_config(config["reward_function"]),
)
return obj
@@ -140,7 +142,7 @@ class AbstractAgent(BaseModel):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state=state, last_action_response=self.config.history[-1])
return self.reward_function.update(state=state, last_action_response=self.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:

View File

@@ -39,9 +39,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
)
random_timestep_increment = random.randint(-self.config.variance, self.config.variance)
self.next_execution_timestep = timestep + random_timestep_increment
def _select_start_node(self) -> None:

View File

@@ -42,7 +42,7 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingA
self.logger.debug(msg="Performing do nothing action")
return "do_nothing", {}
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
self._set_next_execution_timestep(timestep + self.config.frequency)
self.logger.info(msg="Performing a data manipulation attack!")
return "node_application_execute", {
"node_name": self.config.starting_node_name,
@@ -52,4 +52,4 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingA
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step)
self._set_next_execution_timestep(self.config.start_step)

View File

@@ -6,7 +6,7 @@ import numpy as np
import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent, AgentSettings
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = "ProbabilisticAgent"
@@ -17,8 +17,10 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
config: "ProbabilisticAgent.ConfigSchema"
rng: Any = np.random.default_rng(np.random.randint(0, 65535))
class AgentSettings(AgentSettings):
"""ProbabilisticAgent settings."""
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""
agent_name: str = "ProbabilisticAgent"
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
@@ -42,16 +44,10 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
)
return v
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""
agent_name: str = "ProbabilisticAgent"
agent_settings: "ProbabilisticAgent.AgentSettings"
@property
def probabilities(self) -> Dict[str, int]:
"""Convenience method to view the probabilities of the Agent."""
return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
return np.asarray(list(self.config.action_probabilities.values()))
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -538,6 +538,7 @@ class PrimaiteGame:
"observation_manager": observation_space_cfg,
"reward_function": reward_function_cfg,
"agent_settings": agent_settings,
"game": game,
}
# CREATE AGENT

View File

@@ -89,7 +89,7 @@ class PrimaiteGymEnv(gymnasium.Env):
:return: Action mask
:rtype: List[bool]
"""
if not self.agent.action_masking:
if not self.agent.config.action_masking:
return np.asarray([True] * len(self.agent.action_manager.action_map))
else:
return self.game.action_mask(self._agent_name)

View File

@@ -31,10 +31,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: data_manipulation_attacker
team: RED
@@ -63,10 +62,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -56,10 +56,9 @@ agents:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.4
1: 0.6

View File

@@ -59,10 +59,9 @@ agents:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.6
1: 0.4

View File

@@ -151,10 +151,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -81,10 +81,9 @@ agents:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.4
1: 0.6

View File

@@ -35,10 +35,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: data_manipulation_attacker
team: RED
@@ -75,10 +74,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -151,10 +151,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -266,10 +266,9 @@ agents:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
start_step: 5
frequency: 4
variance: 3

View File

@@ -56,10 +56,9 @@ agents:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.4
1: 0.6

View File

@@ -150,10 +150,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender_1
team: BLUE

View File

@@ -20,7 +20,6 @@ reds: &reds
- type: DUMMY
agent_settings:
start_settings:
start_step: 10
frequency: 10
variance: 0
start_step: 10
frequency: 10
variance: 0

View File

@@ -20,7 +20,6 @@ reds: &reds
- type: DUMMY
agent_settings:
start_settings:
start_step: 3
frequency: 2
variance: 1
start_step: 3
frequency: 2
variance: 1

View File

@@ -146,10 +146,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -56,10 +56,9 @@ agents:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
start_step: 5
frequency: 4
variance: 3
action_probabilities:
0: 0.4
1: 0.6

View File

@@ -150,10 +150,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -43,10 +43,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
action_probabilities:
0: 1.0
@@ -86,10 +85,9 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple
import pytest
import yaml

View File

@@ -3,6 +3,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
def test_probabilistic_agent():
@@ -25,36 +26,41 @@ def test_probabilistic_agent():
MIN_NODE_FILE_DELETE = 5750
MAX_NODE_FILE_DELETE = 6250
action_space = ActionManager(
actions=[
action_space_cfg = {
"action_list": [
{"type": "DONOTHING"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_FILE_DELETE"},
],
nodes=[
"nodes": [
{
"node_name": "client_1",
"applications": [{"application_name": "WebBrowser"}],
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
},
],
max_folders_per_node=2,
max_files_per_folder=2,
max_services_per_node=2,
max_applications_per_node=2,
max_nics_per_node=2,
max_acl_rules=10,
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
act_map={
"max_folders_per_node": 2,
"max_files_per_folder": 2,
"max_services_per_node": 2,
"max_applications_per_node": 2,
"max_nics_per_node": 2,
"max_acl_rules": 10,
"protocols": ["TCP", "UDP", "ICMP"],
"ports": ["HTTP", "DNS", "ARP"],
"act_map": {
0: {"action": "DONOTHING", "options": {}},
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
},
)
"options": {},
}
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
observation_space_cfg = None
reward_function_cfg = {}
# pa = ProbabilisticAgent(
# agent_name="test_agent",
# action_space=action_space,
@@ -67,9 +73,10 @@ def test_probabilistic_agent():
pa_config = {
"agent_name": "test_agent",
"action_manager": action_space,
"observation_manager": observation_space,
"reward_function": reward_function,
"game": PrimaiteGame(),
"action_manager": action_space_cfg,
"observation_manager": observation_space_cfg,
"reward_function": reward_function_cfg,
"agent_settings": {
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
},