#2869 - Commit before changing branches. Addition of properties to Agent classes and removal of if/else chain in game.py

This commit is contained in:
Charlie Crane
2024-11-21 14:45:35 +00:00
parent 75d4ef2dfd
commit 7435a4dee8
3 changed files with 45 additions and 121 deletions

View File

@@ -96,11 +96,6 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
"""Base class for scripted and RL agents."""
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
action_manager: Optional[ActionManager]
observation_manager: Optional[ObservationManager]
reward_function: Optional[RewardFunction]
config: "AbstractAgent.ConfigSchema"
class ConfigSchema(BaseModel):
@@ -121,39 +116,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
type: str
agent_name: ClassVar[str]
agent_settings: Optional[AgentSettings] = None
history: List[AgentHistoryItem] = []
logger: AgentLog = AgentLog(agent_name)
# def __init__(
# self,
# agent_name: Optional[str],
# action_space: Optional[ActionManager],
# observation_space: Optional[ObservationManager],
# reward_function: Optional[RewardFunction],
# agent_settings: Optional[AgentSettings] = None,
# ) -> None:
# """
# Initialize an agent.
# :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
# :type agent_name: Optional[str]
# :param action_space: Action space for the agent.
# :type action_space: Optional[ActionManager]
# :param observation_space: Observation space for the agent.
# :type observation_space: Optional[ObservationSpace]
# :param reward_function: Reward function for the agent.
# :type reward_function: Optional[RewardFunction]
# :param agent_settings: Configurable Options for Abstracted Agents
# :type agent_settings: Optional[AgentSettings]
# """
# self.agent_name: str = agent_name or "unnamed_agent"
# self.action_manager: Optional[ActionManager] = action_space
# self.observation_manager: Optional[ObservationManager] = observation_space
# self.reward_function: Optional[RewardFunction] = reward_function
# self.agent_settings = agent_settings or AgentSettings()
# self.history: List[AgentHistoryItem] = []
# self.logger = AgentLog(agent_name)
history: List[AgentHistoryItem] = []
action_manager: Optional[ActionManager] = None
observation_manager: Optional[ObservationManager] = None
reward_function: Optional[RewardFunction] = None
agent_settings: Optional[AgentSettings] = None
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
@@ -161,16 +129,25 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
@property
def observation_manager(self) -> ObservationManager:
    """Return the agent's observation manager, as held by ``self.config``.

    NOTE(review): a pydantic field named ``observation_manager`` is still
    declared on the model — a same-named property shadows it; confirm the
    field declaration is removed in this commit.
    """
    return self.config.observation_manager
@property
def action_manager(self) -> ActionManager:
    """Return the agent's action manager, as held by ``self.config``."""
    return self.config.action_manager
@property
def reward_function(self) -> RewardFunction:
    """Return the agent's reward function, as held by ``self.config``."""
    return self.config.reward_function
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
    """Create an agent component from a configuration dictionary.

    :param config: Raw configuration mapping; validated through ``cls.ConfigSchema``.
    :return: The constructed agent instance.
    """
    # The observation/action/reward components are already reachable through
    # the read-only ``observation_manager`` / ``action_manager`` /
    # ``reward_function`` properties, which delegate to ``self.config``.
    # Copying them onto the instance here would fail (those properties define
    # no setters), so the validated config is all the object needs.
    return cls(config=cls.ConfigSchema(**config))
def update_observation(self, state: Dict) -> ObsType:

View File

@@ -10,9 +10,6 @@ from primaite.game.agent.interface import AbstractScriptedAgent
class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation_Agent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
next_execution_timestep: int = 0
starting_node_idx: int = 0
config: "DataManipulationAgent.ConfigSchema"
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
@@ -27,13 +24,23 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation
super().__init__(*args, **kwargs)
self.setup_agent()
@property
def next_execution_timestep(self):
    """Return the agent's next execution timestep, read from ``self.config``.

    NOTE(review): this property has no setter, yet
    ``_set_next_execution_timestep`` still assigns
    ``self.next_execution_timestep = ...`` — that write will fail; confirm
    whether the value should be stored on ``self.config`` instead.
    """
    return self.config.next_execution_timestep
@property
def starting_node_name(self):
    """Return the agent's starting node name, read from ``self.config``.

    NOTE(review): ``setup_agent`` still writes ``self.starting_node_idx``,
    a field removed in this commit — confirm it is migrated to the
    config-backed ``starting_node_name`` path.
    """
    return self.config.starting_node_name
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
@@ -48,11 +55,11 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
self.logger.debug(msg="Performing do nothing action")
self.config.logger.debug(msg="Performing do nothing action")
return "do_nothing", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
self.logger.info(msg="Performing a data manipulation attack!")
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
self.config.logger.info(msg="Performing a data manipulation attack!")
return "node_application_execute", {
"node_name": self.config.starting_node_name,
"application_name": self.config.starting_application_name,
@@ -68,4 +75,4 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)
self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
self.config.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")

View File

@@ -11,10 +11,6 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
@@ -178,7 +174,7 @@ class PrimaiteGame:
obs = agent.observation_manager.current_observation
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
if SIM_OUTPUT.save_agent_logs:
agent.logger.debug(f"Chosen Action: {action_choice}")
agent.config.logger.debug(f"Chosen Action: {action_choice}")
request = agent.format_request(action_choice, parameters)
response = self.simulation.apply_request(request)
agent.process_action_response(
@@ -548,77 +544,21 @@ class PrimaiteGame:
# CREATE AGENT
# TODO: MAKE THIS BIT WORK AND NOT THE IF/ELSE CHAIN OF HORRORS
# Pass through:
# config
# action manager
# observation_manager
# reward_function
agent_config = agent_cfg.get("agent_settings", {})
agent_config.update({"action_manager": action_space,
"observation_manager": obs_space,
"reward_function":reward_function})
agent_config.update(
{"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function}
)
# new_agent_cfg.update{}
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents
if agent_type == "ProxyAgent":
game.rl_agents[agent_cfg["ref"]] = new_agent
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing
settings = agent_cfg.get("agent_settings", {})
new_agent = ProbabilisticAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
settings=settings,
)
elif agent_type == "PeriodicAgent":
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
new_agent = PeriodicAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
settings=settings,
)
elif agent_type == "ProxyAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
elif agent_type == "TAP001":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = TAP001(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
if agent_type in AbstractAgent._registry:
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents
if agent_type == "ProxyAgent":
game.rl_agents[agent_cfg["ref"]] = new_agent
else:
msg = f"Configuration error: {agent_type} is not a valid agent type."
_LOGGER.error(msg)
raise ValueError(msg)
game.agents[agent_cfg["ref"]] = new_agent
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.