#2869 - eod commit. Updates to AbstractAgent.from_config, and some minor tweaks to PrimaiteGame

This commit is contained in:
Charlie Crane
2024-11-20 17:51:05 +00:00
parent a3dc616126
commit 75d4ef2dfd
4 changed files with 27 additions and 20 deletions

View File

@@ -97,11 +97,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
config: "AbstractAgent.ConfigSchema"
action_manager: Optional[ActionManager]
observation_manager: Optional[ObservationManager]
reward_function: Optional[RewardFunction]
config: "AbstractAgent.ConfigSchema"
class ConfigSchema(BaseModel):
"""
Configuration Schema for AbstractAgents.
@@ -163,7 +164,14 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
return cls(config=cls.ConfigSchema(**config))
obj = cls(config=cls.ConfigSchema(**config))
# Pull managers out of config section for ease of use (?)
obj.observation_manager = obj.config.observation_manager
obj.action_manager = obj.config.action_manager
obj.reward_function = obj.config.reward_function
return obj
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -172,7 +180,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.config.observation_manager.update(state)
return self.observation_manager.update(state)
def update_reward(self, state: Dict) -> float:
"""
@@ -183,7 +191,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
:return: Reward from the state.
:rtype: float
"""
return self.config.reward_function.update(state=state, last_action_response=self.history[-1])
return self.reward_function.update(state=state, last_action_response=self.config.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -201,13 +209,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
"""
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
return ("do_nothing", {})
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.config.action_manager.form_request(action_identifier=action, action_options=options)
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
def process_action_response(
@@ -222,7 +230,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"):
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.config.history[-1].reward = self.config.reward_function.current_reward
self.config.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):