#2869 - Commit before changing branches. Addition of properties to Agent classes and removal of if/else chain in game.py
This commit is contained in:
@@ -96,11 +96,6 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
action_manager: Optional[ActionManager]
|
||||
observation_manager: Optional[ObservationManager]
|
||||
reward_function: Optional[RewardFunction]
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
@@ -121,39 +116,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
|
||||
type: str
|
||||
agent_name: ClassVar[str]
|
||||
agent_settings = Optional[AgentSettings] = None
|
||||
history: List[AgentHistoryItem] = []
|
||||
logger: AgentLog = AgentLog(agent_name)
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# agent_name: Optional[str],
|
||||
# action_space: Optional[ActionManager],
|
||||
# observation_space: Optional[ObservationManager],
|
||||
# reward_function: Optional[RewardFunction],
|
||||
# agent_settings: Optional[AgentSettings] = None,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Initialize an agent.
|
||||
|
||||
# :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
# :type agent_name: Optional[str]
|
||||
# :param action_space: Action space for the agent.
|
||||
# :type action_space: Optional[ActionManager]
|
||||
# :param observation_space: Observation space for the agent.
|
||||
# :type observation_space: Optional[ObservationSpace]
|
||||
# :param reward_function: Reward function for the agent.
|
||||
# :type reward_function: Optional[RewardFunction]
|
||||
# :param agent_settings: Configurable Options for Abstracted Agents
|
||||
# :type agent_settings: Optional[AgentSettings]
|
||||
# """
|
||||
# self.agent_name: str = agent_name or "unnamed_agent"
|
||||
# self.action_manager: Optional[ActionManager] = action_space
|
||||
# self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
# self.reward_function: Optional[RewardFunction] = reward_function
|
||||
# self.agent_settings = agent_settings or AgentSettings()
|
||||
# self.history: List[AgentHistoryItem] = []
|
||||
# self.logger = AgentLog(agent_name)
|
||||
history: List[AgentHistoryItem] = []
|
||||
action_manager: Optional[ActionManager] = None
|
||||
observation_manager: Optional[ObservationManager] = None
|
||||
reward_function: Optional[RewardFunction] = None
|
||||
agent_settings = Optional[AgentSettings] = None
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -161,16 +129,25 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@property
|
||||
def observation_manager(self) -> ObservationManager:
|
||||
"""Returns the agents observation manager."""
|
||||
return self.config.observation_manager
|
||||
|
||||
@property
|
||||
def action_manager(self) -> ActionManager:
|
||||
"""Returns the agents action manager."""
|
||||
return self.config.action_manager
|
||||
|
||||
@property
|
||||
def reward_function(self) -> RewardFunction:
|
||||
"""Returns the agents reward function."""
|
||||
return self.config.reward_function
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractAgent":
|
||||
"""Creates an agent component from a configuration dictionary."""
|
||||
obj = cls(config=cls.ConfigSchema(**config))
|
||||
|
||||
# Pull managers out of config section for ease of use (?)
|
||||
obj.observation_manager = obj.config.observation_manager
|
||||
obj.action_manager = obj.config.action_manager
|
||||
obj.reward_function = obj.config.reward_function
|
||||
|
||||
return obj
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
|
||||
@@ -10,9 +10,6 @@ from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation_Agent"):
|
||||
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
|
||||
config: "DataManipulationAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
@@ -27,13 +24,23 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
|
||||
@property
|
||||
def next_execution_timestep(self):
|
||||
"""Returns the agents next execution timestep."""
|
||||
return self.config.next_execution_timestep
|
||||
|
||||
@property
|
||||
def starting_node_name(self):
|
||||
"""Returns the agents starting node name."""
|
||||
return self.config.starting_node_name
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
@@ -48,11 +55,11 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
self.logger.debug(msg="Performing do nothing action")
|
||||
self.config.logger.debug(msg="Performing do nothing action")
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
self.logger.info(msg="Performing a data manipulation attack!")
|
||||
self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency)
|
||||
self.config.logger.info(msg="Performing a data manipulation attack!")
|
||||
return "node_application_execute", {
|
||||
"node_name": self.config.starting_node_name,
|
||||
"application_name": self.config.starting_application_name,
|
||||
@@ -68,4 +75,4 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
|
||||
self.config.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
|
||||
|
||||
@@ -11,10 +11,6 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.creation import NetworkNodeAdder
|
||||
@@ -178,7 +174,7 @@ class PrimaiteGame:
|
||||
obs = agent.observation_manager.current_observation
|
||||
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
|
||||
if SIM_OUTPUT.save_agent_logs:
|
||||
agent.logger.debug(f"Chosen Action: {action_choice}")
|
||||
agent.config.logger.debug(f"Chosen Action: {action_choice}")
|
||||
request = agent.format_request(action_choice, parameters)
|
||||
response = self.simulation.apply_request(request)
|
||||
agent.process_action_response(
|
||||
@@ -548,77 +544,21 @@ class PrimaiteGame:
|
||||
|
||||
# CREATE AGENT
|
||||
|
||||
# TODO: MAKE THIS BIT WORK AND NOT THE IF/ELSE CHAIN OF HORRORS
|
||||
|
||||
# Pass through:
|
||||
# config
|
||||
# action manager
|
||||
# observation_manager
|
||||
# reward_function
|
||||
agent_config = agent_cfg.get("agent_settings", {})
|
||||
agent_config.update({"action_manager": action_space,
|
||||
"observation_manager": obs_space,
|
||||
"reward_function":reward_function})
|
||||
agent_config.update(
|
||||
{"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function}
|
||||
)
|
||||
# new_agent_cfg.update{}
|
||||
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
|
||||
|
||||
# If blue agent is created, add to game.rl_agents
|
||||
if agent_type == "ProxyAgent":
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
settings = agent_cfg.get("agent_settings", {})
|
||||
new_agent = ProbabilisticAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
elif agent_type == "PeriodicAgent":
|
||||
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
|
||||
new_agent = PeriodicAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
elif agent_type == "ProxyAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
elif agent_type == "TAP001":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = TAP001(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
if agent_type in AbstractAgent._registry:
|
||||
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
|
||||
# If blue agent is created, add to game.rl_agents
|
||||
if agent_type == "ProxyAgent":
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
else:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
|
||||
Reference in New Issue
Block a user