#2869 - eod commit. Updates to AbstractAgent.from_config, and some minor tweaks to PrimaiteGame
This commit is contained in:
@@ -97,11 +97,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
action_manager: Optional[ActionManager]
|
||||
observation_manager: Optional[ObservationManager]
|
||||
reward_function: Optional[RewardFunction]
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""
|
||||
Configuration Schema for AbstractAgents.
|
||||
@@ -163,7 +164,14 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractAgent":
|
||||
"""Creates an agent component from a configuration dictionary."""
|
||||
return cls(config=cls.ConfigSchema(**config))
|
||||
obj = cls(config=cls.ConfigSchema(**config))
|
||||
|
||||
# Pull managers out of config section for ease of use (?)
|
||||
obj.observation_manager = obj.config.observation_manager
|
||||
obj.action_manager = obj.config.action_manager
|
||||
obj.reward_function = obj.config.reward_function
|
||||
|
||||
return obj
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -172,7 +180,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.config.observation_manager.update(state)
|
||||
return self.observation_manager.update(state)
|
||||
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
@@ -183,7 +191,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.config.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
return self.reward_function.update(state=state, last_action_response=self.config.history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -201,13 +209,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
"""
|
||||
# in an RL agent, this method will send the CAOS observation to the RL agent, then receive an int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
return ("do_nothing", {})
|
||||
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.config.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def process_action_response(
|
||||
@@ -222,7 +230,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
|
||||
def save_reward_to_history(self) -> None:
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.config.history[-1].reward = self.config.reward_function.current_reward
|
||||
self.config.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
|
||||
|
||||
@@ -15,7 +15,7 @@ from primaite.game.agent.rewards import RewardFunction
|
||||
class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"):
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
class Settings(pydantic.BaseModel):
|
||||
class ConfigSchema(pydantic.BaseModel):
|
||||
"""Config schema for Probabilistic agent settings."""
|
||||
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
@@ -60,7 +60,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the sequence of rng_seed values used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
self.settings = ProbabilisticAgent.ConfigSchema(**settings)
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class TAP001(AbstractScriptedAgent, identifier="TAP001"):
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
# TODO: Link with DataManipulationAgent via a parent "TAP" agent class.
|
||||
# TODO: Link with DataManipulationAgent class via a parent "TAP" agent class.
|
||||
|
||||
config: "TAP001.ConfigSchema"
|
||||
|
||||
|
||||
@@ -555,17 +555,16 @@ class PrimaiteGame:
|
||||
# action manager
|
||||
# observation_manager
|
||||
# reward_function
|
||||
|
||||
new_agent_cfg = {
|
||||
"action_manager": action_space,
|
||||
"agent_name": agent_cfg["ref"],
|
||||
"observation_manager": obs_space,
|
||||
"agent_settings": agent_cfg.get("agent_settings", {}),
|
||||
"reward_function": reward_function,
|
||||
}
|
||||
new_agent_cfg = agent_cfg["settings"]
|
||||
agent_config = agent_cfg.get("agent_settings", {})
|
||||
agent_config.update({"action_manager": action_space,
|
||||
"observation_manager": obs_space,
|
||||
"reward_function":reward_function})
|
||||
# new_agent_cfg.update{}
|
||||
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=new_agent_cfg)
|
||||
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
|
||||
|
||||
# If blue agent is created, add to game.rl_agents
|
||||
if agent_type == "ProxyAgent":
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
|
||||
Reference in New Issue
Block a user