#2869 - Fixed failing tests from agent refactor. Some tests still fail but this is due to updating some action names in anticipation of merging in the extensible actions refactor

This commit is contained in:
Charlie Crane
2024-12-17 10:51:57 +00:00
parent a4fbd29bb4
commit 436a986458
4 changed files with 25 additions and 13 deletions

View File

@@ -15,13 +15,18 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
config: "AbstractTAPAgent.ConfigSchema"
agent_name: str = "Abstract_TAP"
_next_execution_timestep: int
next_execution_timestep: int = 0
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Abstract TAP agents."""
starting_node_name: Optional[str] = None
# @property
# def next_execution_timestep(self) -> int:
# """Returns the agents next execution timestep."""
# return self.next_execution_timestep
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Return an action to be taken in the environment."""
@@ -40,7 +45,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
random_timestep_increment = random.randint(
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
)
self._next_execution_timestep = timestep + random_timestep_increment
self.next_execution_timestep = timestep + random_timestep_increment
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""

View File

@@ -14,18 +14,14 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingA
class ConfigSchema(AbstractTAPAgent.ConfigSchema):
"""Configuration Schema for DataManipulationAgent."""
starting_application_name: Optional[str] = None
# def __init__(self, **kwargs: Any) -> None:
# """Initialise DataManipulationAgent."""
# # self.setup_agent()
# self.setup_agent()
# super().__init_subclass__(**kwargs)
@property
def next_execution_timestep(self) -> int:
"""Returns the agents next execution timestep."""
return self._next_execution_timestep
@property
def starting_node_name(self) -> str:
"""Returns the agents starting node name."""

View File

@@ -135,6 +135,15 @@ class AbstractAgent(BaseModel):
"""Return the AgentLog."""
return self.config._logger
@property
def flatten_obs(self) -> bool:
return self.config.agent_settings.flatten_obs
@property
def history(self) -> List[AgentHistoryItem]:
"""Return the agent history"""
return self.config.history
@property
def observation_manager(self) -> ObservationManager:
"""Returns the agents observation manager."""
@@ -236,7 +245,7 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
"""Agent that sends observations to an RL model and receives actions from that model."""
config: "ProxyAgent.ConfigSchema"
_most_recent_action: ActType
most_recent_action: ActType = None
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Proxy Agent."""
@@ -246,10 +255,10 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
action_masking: bool = agent_settings.action_masking if agent_settings else False
@property
def most_recent_action(self) -> ActType:
"""Convenience method to access the agents most recent action."""
return self._most_recent_action
# @property
# def most_recent_action(self) -> ActType:
# """Convenience method to access the agents most recent action."""
# return self._most_recent_action
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -47,6 +47,8 @@ agents:
start_step: 25
frequency: 20
variance: 5
action_probabilities:
0: 1.0
- ref: data_manipulation_attacker
team: RED