#2869 - eod commit. Updates to AbstractAgent.from_config, and some minor tweaks to PrimaiteGame
This commit is contained in:
@@ -97,11 +97,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
action_manager: Optional[ActionManager]
|
||||
observation_manager: Optional[ObservationManager]
|
||||
reward_function: Optional[RewardFunction]
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""
|
||||
Configuration Schema for AbstractAgents.
|
||||
@@ -163,7 +164,14 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractAgent":
|
||||
"""Creates an agent component from a configuration dictionary."""
|
||||
return cls(config=cls.ConfigSchema(**config))
|
||||
obj = cls(config=cls.ConfigSchema(**config))
|
||||
|
||||
# Pull managers out of config section for ease of use (?)
|
||||
obj.observation_manager = obj.config.observation_manager
|
||||
obj.action_manager = obj.config.action_manager
|
||||
obj.reward_function = obj.config.reward_function
|
||||
|
||||
return obj
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -172,7 +180,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.config.observation_manager.update(state)
|
||||
return self.observation_manager.update(state)
|
||||
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
@@ -183,7 +191,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.config.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
return self.reward_function.update(state=state, last_action_response=self.config.history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -201,13 +209,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
return ("do_nothing", {})
|
||||
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.config.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def process_action_response(
|
||||
@@ -222,7 +230,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
|
||||
|
||||
def save_reward_to_history(self) -> None:
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.config.history[-1].reward = self.config.reward_function.current_reward
|
||||
self.config.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
|
||||
|
||||
Reference in New Issue
Block a user