Update data manipulation bot

Jake Walker
2023-11-24 10:33:19 +00:00
41 changed files with 3648 additions and 499 deletions


@@ -4,10 +4,11 @@ from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
 
 import numpy as np
 from gymnasium.core import ActType, ObsType
 from pydantic import BaseModel
 
 from primaite.game.agent.actions import ActionManager
-from primaite.game.agent.observations import ObservationSpace
+from primaite.game.agent.observations import ObservationManager
 from primaite.game.agent.rewards import RewardFunction
 
 if TYPE_CHECKING:
@@ -55,7 +56,7 @@ class AbstractAgent(ABC):
         self,
         agent_name: Optional[str],
         action_space: Optional[ActionManager],
-        observation_space: Optional[ObservationSpace],
+        observation_space: Optional[ObservationManager],
         reward_function: Optional[RewardFunction],
         agent_settings: Optional[AgentSettings],
     ) -> None:
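Note that the constructor keeps its old parameter names (action_space, observation_space) even though the attributes they populate are renamed to action_manager and observation_manager in the hunk below. A minimal construction sketch after this change; every parameter is Optional, so None placeholders stand in for real instances here purely for illustration:

    # Hypothetical sketch -- real code would pass configured ActionManager,
    # ObservationManager, and RewardFunction instances instead of None.
    agent = RandomAgent(
        agent_name="demo_agent",  # hypothetical name
        action_space=None,        # an ActionManager in real use
        observation_space=None,   # an ObservationManager in real use
        reward_function=None,     # a RewardFunction in real use
        agent_settings=None,      # __init__ falls back to AgentSettings()
    )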
@@ -72,21 +73,21 @@ class AbstractAgent(ABC):
         :type reward_function: Optional[RewardFunction]
         """
         self.agent_name: str = agent_name or "unnamed_agent"
-        self.action_space: Optional[ActionManager] = action_space
-        self.observation_space: Optional[ObservationSpace] = observation_space
+        self.action_manager: Optional[ActionManager] = action_space
+        self.observation_manager: Optional[ObservationManager] = observation_space
         self.reward_function: Optional[RewardFunction] = reward_function
         self.agent_settings = agent_settings or AgentSettings()
 
-    def convert_state_to_obs(self, state: Dict) -> ObsType:
+    def update_observation(self, state: Dict) -> ObsType:
         """
         Convert a state from the simulator into an observation for the agent using the observation space.
 
         state : dict state directly from simulation.describe_state
         output : dict state according to CAOS.
         """
-        return self.observation_space.observe(state)
+        return self.observation_manager.update(state)
 
-    def calculate_reward_from_state(self, state: Dict) -> float:
+    def update_reward(self, state: Dict) -> float:
         """
         Use the reward function to calculate a reward from the state.
 
@@ -95,10 +96,10 @@ class AbstractAgent(ABC):
         :return: Reward from the state.
         :rtype: float
         """
-        return self.reward_function.calculate(state)
+        return self.reward_function.update(state)
 
     @abstractmethod
-    def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
+    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
         """
         Return an action to be taken in the environment.
 
@@ -111,7 +112,7 @@ class AbstractAgent(ABC):
         :return: Action to be taken in the environment.
         :rtype: Tuple[str, Dict]
         """
-        # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
+        # in RL agent, this method will send CAOS observation to RL agent, then receive an int 0-39,
         # then use a bespoke conversion to take 1-40 int back into CAOS action
         return ("DO_NOTHING", {})
 
@@ -119,7 +120,7 @@ class AbstractAgent(ABC):
         # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
         # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
         """Format action into format expected by the simulator, and apply execution definition if applicable."""
-        request = self.action_space.form_request(action_identifier=action, action_options=options)
+        request = self.action_manager.form_request(action_identifier=action, action_options=options)
         return request
 
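A minimal usage sketch of the form_request call shown above, with a hypothetical identifier and options; which identifiers exist in practice depends on how the ActionManager is configured:

    # `action_manager` is assumed to be a configured ActionManager instance.
    request = action_manager.form_request(
        action_identifier="NODE_APPLICATION_EXECUTE",        # hypothetical CAOS identifier
        action_options={"node_id": 1, "application_id": 0},  # hypothetical options
    )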
@@ -132,7 +133,7 @@ class AbstractScriptedAgent(AbstractAgent):
 class RandomAgent(AbstractScriptedAgent):
     """Agent that ignores its observation and acts completely at random."""
 
-    def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
+    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
         """Randomly sample an action from the action space.
 
         :param obs: _description_
@@ -142,7 +143,47 @@ class RandomAgent(AbstractScriptedAgent):
         :return: _description_
         :rtype: Tuple[str, Dict]
         """
-        return self.action_space.get_action(self.action_space.space.sample())
+        return self.action_manager.get_action(self.action_manager.space.sample())
 
 
+class ProxyAgent(AbstractAgent):
+    """Agent that sends observations to an RL model and receives actions from that model."""
+
+    def __init__(
+        self,
+        agent_name: Optional[str],
+        action_space: Optional[ActionManager],
+        observation_space: Optional[ObservationManager],
+        reward_function: Optional[RewardFunction],
+    ) -> None:
+        super().__init__(
+            agent_name=agent_name,
+            action_space=action_space,
+            observation_space=observation_space,
+            reward_function=reward_function,
+        )
+        self.most_recent_action: ActType
+
+    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
+        """
+        Return the agent's most recent action, formatted in CAOS format.
+
+        :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
+        :type obs: ObsType
+        :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to 0.0.
+        :type reward: float, optional
+        :return: Action to be taken in CAOS format.
+        :rtype: Tuple[str, Dict]
+        """
+        return self.action_manager.get_action(self.most_recent_action)
+
+    def store_action(self, action: ActType):
+        """
+        Store the most recent action taken by the agent.
+
+        The environment is responsible for calling this method when it receives an action from the agent policy.
+        """
+        self.most_recent_action = action
+
+
 class DataManipulationAgent(AbstractScriptedAgent):
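Read together, ProxyAgent's store_action and get_action imply a two-phase hand-off in the environment loop: the environment first records the raw policy output, then asks the agent to replay it in CAOS form. A minimal sketch under that reading; the glue function is hypothetical, and it assumes the simulator-formatting method from the earlier hunk is reachable as format_request (its def line sits outside the shown hunk):

    # Hypothetical glue code: the env stores the policy's raw action, then the
    # ProxyAgent replays it as a CAOS (identifier, options) pair.
    def apply_policy_action(proxy_agent: ProxyAgent, raw_action: int) -> None:
        proxy_agent.store_action(raw_action)  # raw ActType from the RL policy
        action, options = proxy_agent.get_action(obs={}, reward=0.0)  # obs is ignored by ProxyAgent
        request = proxy_agent.format_request(action, options)  # hand off to the simulator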