Update agent interface to work better with envs

This commit is contained in:
Marek Wolan
2023-11-14 15:10:07 +00:00
parent 1cb54da2dd
commit e6ead6e532
7 changed files with 240 additions and 97 deletions

View File

@@ -4,8 +4,12 @@ training_config:
seed: 333
n_learn_episodes: 20
n_learn_steps: 128
n_eval_episodes: 20
n_eval_episodes: 5
n_eval_steps: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game_config:
@@ -108,7 +112,7 @@ game_config:
- ref: defender
team: BLUE
type: idk???
type: RLAgent
observation_space:
type: UC2BlueObservation

View File

@@ -1,15 +1,13 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
from gymnasium.core import ActType, ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
@@ -18,7 +16,7 @@ class AbstractAgent(ABC):
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
) -> None:
"""
@@ -34,24 +32,24 @@ class AbstractAgent(ABC):
:type reward_function: Optional[RewardFunction]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
# execution definition converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
def convert_state_to_obs(self, state: Dict) -> ObsType:
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
return self.observation_manager.update(state)
def calculate_reward_from_state(self, state: Dict) -> float:
def update_reward(self, state: Dict) -> float:
"""
Use the reward function to calculate a reward from the state.
@@ -60,10 +58,10 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.calculate(state)
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -84,7 +82,7 @@ class AbstractAgent(ABC):
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.action_space.form_request(action_identifier=action, action_options=options)
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
@@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent):
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
@@ -107,4 +105,44 @@ class RandomAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
return self.action_space.get_action(self.action_space.space.sample())
return self.action_manager.get_action(self.action_manager.space.sample())
class ProxyAgent(AbstractAgent):
    """Agent that sends observations to an RL model and receives actions from that model."""

    def __init__(
        self,
        agent_name: Optional[str],
        action_space: Optional[ActionManager],
        observation_space: Optional[ObservationManager],
        reward_function: Optional[RewardFunction],
    ) -> None:
        """
        Initialise a proxy agent.

        :param agent_name: Name of the agent (the base class falls back to "unnamed_agent" if None).
        :type agent_name: Optional[str]
        :param action_space: Action manager used to translate the stored raw action into CAOS format.
        :type action_space: Optional[ActionManager]
        :param observation_space: Observation manager for this agent.
        :type observation_space: Optional[ObservationManager]
        :param reward_function: Reward function for this agent.
        :type reward_function: Optional[RewardFunction]
        """
        super().__init__(
            agent_name=agent_name,
            action_space=action_space,
            observation_space=observation_space,
            reward_function=reward_function,
        )
        # Declared but deliberately not initialised: it is unset until the environment calls
        # store_action(). NOTE(review): calling get_action() before any store_action() call would
        # raise AttributeError — confirm the env always stores an action before stepping.
        self.most_recent_action: ActType

    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
        """
        Return the agent's most recent action, formatted in CAOS format.

        :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
        :type obs: ObsType
        :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to 0.0.
        :type reward: float, optional
        :return: Action to be taken in CAOS format.
        :rtype: Tuple[str, Dict]
        """
        return self.action_manager.get_action(self.most_recent_action)

    def store_action(self, action: ActType) -> None:
        """
        Store the most recent action taken by the agent.

        The environment is responsible for calling this method when it receives an action from the
        agent policy, before the session applies agent actions.

        :param action: Raw action chosen by the RL policy.
        :type action: ActType
        """
        self.most_recent_action = action

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
pass
class ObservationSpace:
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -947,15 +948,17 @@ class ObservationSpace:
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.current_observation: ObsType
def observe(self, state: Dict) -> Dict:
def update(self, state: Dict) -> Dict:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
"""
return self.obs.observe(state)
self.current_observation = self.obs.observe(state)
return self.current_observation
@property
def space(self) -> None:
@@ -963,7 +966,7 @@ class ObservationSpace:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.

View File

@@ -238,6 +238,7 @@ class RewardFunction:
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.
@@ -249,7 +250,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def calculate(self, state: Dict) -> float:
def update(self, state: Dict) -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -260,7 +261,8 @@ class RewardFunction:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
return total
self.current_reward = total
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":

View File

@@ -1,18 +1,47 @@
from abc import ABC, abstractclassmethod, abstractmethod
from typing import TYPE_CHECKING
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from typing import Any, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, name: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
:param name: Identifier used by from_config to create an instance of the policy.
:type name: str
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if name in cls._registry:
raise ValueError(f"Duplicate policy name {name}")
cls._registry[name] = cls
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
"""Initialize a reinforcement learning agent."""
"""
Initialize a reinforcement learning policy.
:param session: The session context.
:type session: PrimaiteSession
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
pass
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, n_time_steps: int) -> None:
@@ -25,30 +54,30 @@ class PolicyABC(ABC):
pass
@abstractmethod
def save(
self,
) -> None:
def save(self) -> None:
"""Save the agent."""
pass
@abstractmethod
def load(
self,
) -> None:
def load(self) -> None:
"""Load agent from a file."""
pass
def close(
self,
) -> None:
def close(self) -> None:
"""Close the agent."""
pass
@abstractclassmethod
def from_config(
cls,
) -> "PolicyABC":
"""Create an agent from a config file."""
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
    """
    Create an RL policy from a config by dispatching to the relevant subclass's ``from_config``.

    Subclasses should not call ``super().from_config()``; they should just handle creation from
    config themselves.

    :param config: Training options; ``config.rl_framework`` selects the registered policy type.
    :type config: TrainingOptions
    :param session: The session context that the policy will train within.
    :type session: PrimaiteSession
    :raises KeyError: If ``config.rl_framework`` does not match any registered policy name.
    :return: A fully constructed policy of the requested type.
    :rtype: PolicyABC
    """
    policy_class = cls._registry[config.rl_framework]
    # BUGFIX: the subclass from_config implementations (e.g. SB3Policy.from_config) require the
    # config and session — calling with no arguments raised TypeError. Forward both through.
    # Checkpoint-saving logic will be handled here; it will invoke the 'save' method which is
    # implemented by the subclass.
    return policy_class.from_config(config=config, session=session)

View File

@@ -1,4 +1,5 @@
from typing import Literal, TYPE_CHECKING, Union
"""Stable baselines 3 policy."""
from typing import Literal, Optional, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
@@ -7,13 +8,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"]):
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
@@ -29,8 +30,8 @@ class SB3Policy(PolicyABC):
self._agent = self._agent_class(
policy=policy,
env=self.session.env,
n_steps=...,
seed=...,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
) # TODO: populate values once I figure out how to get them from the config / session
def learn(self, n_episodes: int, n_time_steps: int) -> None:
@@ -68,6 +69,6 @@ class SB3Policy(PolicyABC):
pass
@classmethod
def from_config(self) -> "SB3Policy":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
pass
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,13 +1,15 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
@@ -31,6 +33,58 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single-agent environment, since gymnasium is a single-agent API. Therefore we
    can assume that the agent list always has exactly one entry.
    """

    def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]) -> None:
        """
        Initialise the environment.

        :param session: The session that owns the simulation this env drives.
        :type session: PrimaiteSession
        :param agents: List of proxy agents; only the first entry is used (single-agent API).
        :type agents: List[ProxyAgent]
        """
        super().__init__()
        self.session: "PrimaiteSession" = session
        self.agent: ProxyAgent = agents[0]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Perform a step in the environment.

        :param action: Action chosen by the RL policy.
        :type action: ActType
        :return: ``(observation, reward, terminated, truncated, info)`` per the gymnasium API.
        :rtype: Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]
        """
        # Make the ProxyAgent store the action chosen by the RL policy;
        # apply_agent_actions accesses the action we just stored.
        self.agent.store_action(action)
        self.session.apply_agent_actions()
        self.session.advance_timestep()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self.agent.observation_manager.current_observation
        reward = self.agent.reward_function.current_reward
        terminated = False
        # BUGFIX: this was `truncated = ...` (the Ellipsis object), which is not a bool and breaks
        # the gymnasium step contract. Truncation at the episode step limit is left to the training
        # loop / wrappers for now. TODO: truncate here once the episode length is available.
        truncated = False
        info: dict[str, Any] = {}
        return next_obs, reward, terminated, truncated, info

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[ObsType, dict[str, Any]]:
        """
        Reset the environment.

        :param seed: Seed for the environment's random number generator, defaults to None.
        :type seed: Optional[int], optional
        :param options: Extra reset options, accepted for gymnasium API compatibility; unused.
        :type options: Optional[dict], optional
        :return: ``(observation, info)`` per the gymnasium API.
        :rtype: tuple[ObsType, dict[str, Any]]
        """
        # Seed gymnasium's internal RNG as required by the gymnasium reset contract.
        super().reset(seed=seed)
        self.session.reset()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self.agent.observation_manager.current_observation
        info: dict[str, Any] = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        # NOTE(review): elsewhere the sampleable space is accessed as `action_manager.space`
        # (see RandomAgent) — confirm ActionManager exposes `action_space` as well.
        return self.agent.action_manager.action_space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the observation space of the environment."""
        return self.agent.observation_manager.observation_space
class PrimaiteSessionOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
@@ -45,28 +99,29 @@ class PrimaiteSessionOptions(BaseModel):
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
rl_framework: str
rl_algorithm: str
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
seed: Optional[int]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int
n_eval_steps: int
n_eval_episodes: int = 0
n_eval_steps: Optional[int] = None
deterministic_eval: bool
n_agents: int
agent_references: List[str]
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
def __init__(self):
"""Initialise a PrimaiteSession object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
# self.rl_agent: AbstractAgent
# """The agent from the list which communicates with GATE to perform reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
@@ -94,8 +149,10 @@ class PrimaiteSession:
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
# self.env:
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
"""Commence the training session."""
n_learn_steps = self.training_options.n_learn_steps
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_steps = self.training_options.n_eval_steps
@@ -119,40 +176,47 @@ class PrimaiteSession:
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
interacts with should implement a step method that calls methods used by this method. For example, if using a
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
# 5. primaite session asks agents to produce an action based on most recent state
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
sim_state = self.simulation.describe_state()
agent.update_observation(state)
agent.update_reward(state)
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.convert_state_to_obs(sim_state)
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.calculate_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
_LOGGER.debug("Getting agent action")
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
self.simulation.apply_request(agent_request)
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
def advance_timestep(self) -> None:
"""Advance timestep."""
self.simulation.apply_timestep(self.step_counter)
self.step_counter += 1
@@ -161,7 +225,7 @@ class PrimaiteSession:
return NotImplemented
def close(self) -> None:
"""Close the session, this will stop the gate client and close the simulation."""
"""Close the session, this will stop the env and close the simulation."""
return NotImplemented
@classmethod
@@ -169,7 +233,7 @@ class PrimaiteSession:
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent. Used by GATE.
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
@@ -323,7 +387,7 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
@@ -359,15 +423,14 @@ class PrimaiteSession:
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
new_agent = RandomAgent(
elif agent_type == "RLAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agent = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
@@ -379,4 +442,7 @@ class PrimaiteSession:
else:
print("agent type not found")
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options)
return sess