#2869 - pre-commit changes

This commit is contained in:
Charlie Crane
2024-12-17 12:50:14 +00:00
parent 770896200b
commit dc6f2be209
7 changed files with 33 additions and 28 deletions

View File

@@ -7,9 +7,10 @@ from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractAgent, AbstractScriptedAgent
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
__all__ = "AbstractTAPAgent"
__all__ = ("AbstractTAPAgent")
class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
"""Base class for TAP agents to inherit from."""

View File

@@ -1,11 +1,12 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Any, Dict, Optional, Tuple
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent
__all__ = ("DataManipulationAgent")
__all__ = "DataManipulationAgent"
class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingAgent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""

View File

@@ -17,11 +17,8 @@ from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
__all__ = ("AgentHistoryItem",
"AgentStartSettings",
"AbstractAgent",
"AbstractScriptedAgent",
"ProxyAgent")
__all__ = ("AgentHistoryItem", "AgentStartSettings", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
class AgentHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
@@ -121,7 +118,7 @@ class AbstractAgent(BaseModel):
"""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
agent_name: ClassVar[str] = "Abstract_Agent" # TODO: Make this a ClassVar[str] like verb in actions?
agent_name: ClassVar[str] = "Abstract_Agent" # TODO: Make this a ClassVar[str] like verb in actions?
history: List[AgentHistoryItem] = []
_logger: AgentLog = AgentLog(agent_name=agent_name)
action_manager: ActionManager
@@ -142,11 +139,12 @@ class AbstractAgent(BaseModel):
@property
def flatten_obs(self) -> bool:
"""Return agent flatten_obs param."""
return self.config.agent_settings.flatten_obs
@property
def history(self) -> List[AgentHistoryItem]:
"""Return the agent history"""
"""Return the agent history."""
return self.config.history
@property

View File

@@ -8,7 +8,8 @@ from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings
__all__ = ("ProbabilisticAgent")
__all__ = "ProbabilisticAgent"
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
@@ -17,8 +18,11 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
rng: Any = np.random.default_rng(np.random.randint(0, 65535))
class AgentSettings(AgentSettings):
"""ProbabilisticAgent settings."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
@@ -44,7 +48,6 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
agent_name: str = "ProbabilisticAgent"
agent_settings: "ProbabilisticAgent.AgentSettings"
@property
def probabilities(self) -> Dict[str, int]:
"""Convenience method to view the probabilities of the Agent."""

View File

@@ -8,6 +8,7 @@ from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
__all__ = ("RandomAgent", "PeriodicAgent")
class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):
"""Agent that ignores its observation and acts completely at random."""
@@ -45,11 +46,11 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
"The timestep at which an agent begins performing it's actions."
start_variance: int = 5
"Deviation around the start step."
# TODO: This is available in config.agent_settings.start_settings.frequency
frequency: int = 5
"The number of timesteps to wait between performing actions."
# TODO: This is available in config.agent_settings.start_settings.variance
variance: int = 0
"The amount the frequency can randomly change to."
@@ -57,7 +58,7 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"):
"Maximum number of times the agent can execute its action."
num_executions: int = 0
"""Number of times the agent has executed an action."""
#TODO: Also in abstract_tap - move up and inherit? Add to AgentStartSettings?
# TODO: Also in abstract_tap - move up and inherit? Add to AgentStartSettings?
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""

View File

@@ -497,12 +497,11 @@ def game_and_agent():
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
config = {
"agent_name":"test_agent",
"action_manager":action_space,
"observation_manager":observation_space,
"reward_function":reward_function,
"agent_name": "test_agent",
"action_manager": action_space,
"observation_manager": observation_space,
"reward_function": reward_function,
}
test_agent = ControlledAgent.from_config(config=config)

View File

@@ -65,13 +65,15 @@ def test_probabilistic_agent():
# },
# )
pa_config = {"agent_name":"test_agent",
"action_manager": action_space,
"observation_manager": observation_space,
"reward_function": reward_function,
"agent_settings": {
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
}}
pa_config = {
"agent_name": "test_agent",
"action_manager": action_space,
"observation_manager": observation_space,
"reward_function": reward_function,
"agent_settings": {
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
},
}
pa = ProbabilisticAgent.from_config(config=pa_config)