Update data manipulation bot
This commit is contained in:
@@ -4,10 +4,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -55,7 +56,7 @@ class AbstractAgent(ABC):
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings],
|
||||
) -> None:
|
||||
@@ -72,21 +73,21 @@ class AbstractAgent(ABC):
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
return self.observation_manager.update(state)
|
||||
|
||||
def calculate_reward_from_state(self, state: Dict) -> float:
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
Use the reward function to calculate a reward from the state.
|
||||
|
||||
@@ -95,10 +96,10 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.calculate(state)
|
||||
return self.reward_function.update(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
@@ -111,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
:return: Action to be taken in the environment.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
|
||||
@@ -119,7 +120,7 @@ class AbstractAgent(ABC):
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.action_space.form_request(action_identifier=action, action_options=options)
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
|
||||
@@ -132,7 +133,7 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
|
||||
:param obs: _description_
|
||||
@@ -142,7 +143,47 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
|
||||
:type reward: float, optional
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.most_recent_action)
|
||||
|
||||
def store_action(self, action: ActType):
|
||||
"""
|
||||
Store the most recent action taken by the agent.
|
||||
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
Reference in New Issue
Block a user