#2869 - eod commit. Updates to AbstractAgent.from_config, and some minor tweaks to PrimaiteGame

This commit is contained in:
Charlie Crane
2024-11-20 17:51:05 +00:00
parent a3dc616126
commit 75d4ef2dfd
4 changed files with 27 additions and 20 deletions

View File

@@ -97,11 +97,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
config: "AbstractAgent.ConfigSchema"
action_manager: Optional[ActionManager]
observation_manager: Optional[ObservationManager]
reward_function: Optional[RewardFunction]
config: "AbstractAgent.ConfigSchema"
class ConfigSchema(BaseModel):
"""
Configuration Schema for AbstractAgents.
@@ -163,7 +164,14 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
return cls(config=cls.ConfigSchema(**config))
obj = cls(config=cls.ConfigSchema(**config))
# Pull managers out of config section for ease of use (?)
obj.observation_manager = obj.config.observation_manager
obj.action_manager = obj.config.action_manager
obj.reward_function = obj.config.reward_function
return obj
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -172,7 +180,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.config.observation_manager.update(state)
return self.observation_manager.update(state)
def update_reward(self, state: Dict) -> float:
"""
@@ -183,7 +191,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
:return: Reward from the state.
:rtype: float
"""
return self.config.reward_function.update(state=state, last_action_response=self.history[-1])
return self.reward_function.update(state=state, last_action_response=self.config.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -201,13 +209,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
"""
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
return ("do_nothing", {})
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.config.action_manager.form_request(action_identifier=action, action_options=options)
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
def process_action_response(
@@ -222,7 +230,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.config.history[-1].reward = self.config.reward_function.current_reward
self.config.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):

View File

@@ -15,7 +15,7 @@ from primaite.game.agent.rewards import RewardFunction
class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
class Settings(pydantic.BaseModel):
class ConfigSchema(pydantic.BaseModel):
"""Config schema for Probabilistic agent settings."""
model_config = pydantic.ConfigDict(extra="forbid")
@@ -60,7 +60,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the sequence of rng_seeds used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
self.settings = ProbabilisticAgent.ConfigSchema(**settings)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)

View File

@@ -16,7 +16,7 @@ class TAP001(AbstractScriptedAgent, identifier="TAP001"):
Scripted Red Agent. Capable of one action: launching the kill-chain (Ransomware Application)
"""
# TODO: Link with DataManipulationAgent via a parent "TAP" agent class.
# TODO: Link with DataManipulationAgent class via a parent "TAP" agent class.
config: "TAP001.ConfigSchema"

View File

@@ -555,17 +555,16 @@ class PrimaiteGame:
# action manager
# observation_manager
# reward_function
new_agent_cfg = {
"action_manager": action_space,
"agent_name": agent_cfg["ref"],
"observation_manager": obs_space,
"agent_settings": agent_cfg.get("agent_settings", {}),
"reward_function": reward_function,
}
new_agent_cfg = agent_cfg["settings"]
agent_config = agent_cfg.get("agent_settings", {})
agent_config.update({"action_manager": action_space,
"observation_manager": obs_space,
"reward_function":reward_function})
# new_agent_cfg.update{}
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=new_agent_cfg)
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents
if agent_type == "ProxyAgent":
game.rl_agents[agent_cfg["ref"]] = new_agent
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing