diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 676028bb..0c39333c 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -4,8 +4,12 @@ training_config: seed: 333 n_learn_episodes: 20 n_learn_steps: 128 - n_eval_episodes: 20 + n_eval_episodes: 5 n_eval_steps: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender game_config: @@ -108,7 +112,7 @@ game_config: - ref: defender team: BLUE - type: idk??? + type: RLAgent observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index e3b98777..75d209ce 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,15 +1,13 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Dict, List, Optional, Tuple -import numpy as np +from gymnasium.core import ActType, ObsType from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -ObsType: TypeAlias = Union[Dict, np.ndarray] - class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -18,7 +16,7 @@ class AbstractAgent(ABC): self, agent_name: Optional[str], action_space: Optional[ActionManager], - observation_space: Optional[ObservationSpace], + observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], ) -> None: """ @@ -34,24 +32,24 @@ class AbstractAgent(ABC): :type reward_function: Optional[RewardFunction] """ self.agent_name: str = agent_name or "unnamed_agent" - self.action_space: Optional[ActionManager] = action_space - self.observation_space: 
Optional[ObservationSpace] = observation_space + self.action_manager: Optional[ActionManager] = action_space + self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function # exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info # by for example specifying target ip addresses, or converting a node ID into a uuid self.execution_definition = None - def convert_state_to_obs(self, state: Dict) -> ObsType: + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ - return self.observation_space.observe(state) + return self.observation_manager.update(state) - def calculate_reward_from_state(self, state: Dict) -> float: + def update_reward(self, state: Dict) -> float: """ Use the reward function to calculate a reward from the state. @@ -60,10 +58,10 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.calculate(state) + return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -84,7 +82,7 @@ class AbstractAgent(ABC): # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. 
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - request = self.action_space.form_request(action_identifier=action, action_options=options) + request = self.action_manager.form_request(action_identifier=action, action_options=options) return request @@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -107,4 +105,44 @@ class RandomAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - return self.action_space.get_action(self.action_space.space.sample()) + return self.action_manager.get_action(self.action_manager.space.sample()) + + +class ProxyAgent(AbstractAgent): + """Agent that sends observations to an RL model and receives actions from that model.""" + + def __init__( + self, + agent_name: Optional[str], + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + ) -> None: + super().__init__( + agent_name=agent_name, + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + self.most_recent_action: ActType + + def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + """ + Return the agent's most recent action, formatted in CAOS format. + + :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. + :type obs: ObsType + :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. 
+ :type reward: float, optional + :return: Action to be taken in CAOS format. + :rtype: Tuple[str, Dict] + """ + return self.action_manager.get_action(self.most_recent_action) + + def store_action(self, action: ActType): + """ + Store the most recent action taken by the agent. + + The environment is responsible for calling this method when it receives an action from the agent policy. + """ + self.most_recent_action = action diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a3bafeea..a74771c0 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces +from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation): pass -class ObservationSpace: +class ObservationManager: """ Manage the observations of an Agent. @@ -947,15 +948,17 @@ class ObservationSpace: :type observation: AbstractObservation """ self.obs: AbstractObservation = observation + self.current_observation: ObsType - def observe(self, state: Dict) -> Dict: + def update(self, state: Dict) -> Dict: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary :type state: Dict """ - return self.obs.observe(state) + self.current_observation = self.obs.observe(state) + return self.current_observation @property def space(self) -> None: @@ -963,7 +966,7 @@ class ObservationSpace: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace": + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": """Create observation space from a config. 
:param config: Dictionary containing the configuration for this observation space. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 6c408ff9..49d56e67 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -238,6 +238,7 @@ class RewardFunction: """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] "attribute reward_components keeps track of reward components and the weights assigned to each." + self.current_reward: float def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None: """Add a reward component to the reward function. @@ -249,7 +250,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def calculate(self, state: Dict) -> float: + def update(self, state: Dict) -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. @@ -260,7 +261,8 @@ class RewardFunction: comp = comp_and_weight[0] weight = comp_and_weight[1] total += weight * comp.calculate(state=state) - return total + self.current_reward = total + return self.current_reward @classmethod def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 404d6f31..5669a4ff 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -1,18 +1,47 @@ -from abc import ABC, abstractclassmethod, abstractmethod -from typing import TYPE_CHECKING +"""Base class and common logic for RL policies.""" +from abc import ABC, abstractmethod +from typing import Any, Dict, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.session import PrimaiteSession, TrainingOptions class PolicyABC(ABC): """Base class for reinforcement learning agents.""" + _registry: Dict[str, type["PolicyABC"]] = 
{} + """ + Registry of policy types, keyed by name. + + Automatically populated when PolicyABC subclasses are defined. Used for defining from_config. + """ + + def __init_subclass__(cls, name: str, **kwargs: Any) -> None: + """ + Register a policy subclass. + + :param name: Identifier used by from_config to create an instance of the policy. + :type name: str + :raises ValueError: When attempting to create a policy with a duplicate name. + """ + super().__init_subclass__(**kwargs) + if name in cls._registry: + raise ValueError(f"Duplicate policy name {name}") + cls._registry[name] = cls + return + @abstractmethod def __init__(self, session: "PrimaiteSession") -> None: - """Initialize a reinforcement learning agent.""" + """ + Initialize a reinforcement learning policy. + + :param session: The session context. + :type session: PrimaiteSession + :param agents: The agents to train. + :type agents: List[RLAgent] + """ self.session: "PrimaiteSession" = session - pass + """Reference to the session.""" @abstractmethod def learn(self, n_episodes: int, n_time_steps: int) -> None: @@ -25,30 +54,30 @@ class PolicyABC(ABC): pass @abstractmethod - def save( - self, - ) -> None: + def save(self) -> None: """Save the agent.""" pass @abstractmethod - def load( - self, - ) -> None: + def load(self) -> None: """Load agent from a file.""" pass - def close( - self, - ) -> None: + def close(self) -> None: """Close the agent.""" pass - @abstractclassmethod - def from_config( - cls, - ) -> "PolicyABC": - """Create an agent from a config file.""" - pass + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": + """ + Create an RL policy from a config by calling the relevant subclass's from_config method. + + Subclasses should not call super().from_config(), they should just handle creation form config. + """ + # Assume that basically the contents of training_config are passed into here. 
+ # I should really define a config schema class using pydantic. + + PolicyType = cls._registry[config.rl_framework] + return PolicyType.from_config(config=config, session=session) # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 2d9da1db..73df1b98 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,4 +1,5 @@ -from typing import Literal, TYPE_CHECKING, Union +"""Stable baselines 3 policy.""" +from typing import Literal, Optional, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP @@ -7,13 +8,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.session import PrimaiteSession, TrainingOptions -class SB3Policy(PolicyABC): +class SB3Policy(PolicyABC, name="SB3"): """Single agent RL policy using stable baselines 3.""" - def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"]): + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) @@ -29,8 +30,8 @@ class SB3Policy(PolicyABC): self._agent = self._agent_class( policy=policy, env=self.session.env, - n_steps=..., - seed=..., + n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch + seed=seed, ) # TODO: populate values once I figure out how to get them from the config / session def learn(self, n_episodes: int, n_time_steps: int) -> None: @@ -68,6 +69,6 @@ class SB3Policy(PolicyABC): pass @classmethod - def from_config(self) -> "SB3Policy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": """Create an agent from config file.""" - pass + return 
cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 9d241932..5556dd87 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,13 +1,15 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple +import gymnasium +from gymnasium.core import ActType, ObsType from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, RandomAgent -from primaite.game.agent.observations import ObservationSpace +from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent +from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -31,6 +33,58 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) +class PrimaiteEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. + + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. 
+ """ + + def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + """Initialise the environment.""" + super().__init__() + self.session: "PrimaiteSession" = session + self.agent: ProxyAgent = agents[0] + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen my the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.session.apply_agent_actions() + self.session.advance_timestep() + state = self.session.get_sim_state() + self.session.update_agents(state) + + next_obs = self.agent.observation_manager.current_observation + reward = self.agent.reward_function.current_reward + terminated = False + truncated = ... + info = {} + + return next_obs, reward, terminated, truncated, info + + def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]: + """Reset the environment.""" + self.session.reset() + state = self.session.get_sim_state() + self.session.update_agents(state) + next_obs = self.agent.observation_manager.current_observation + info = {} + return next_obs, info + + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.action_space + + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return self.agent.observation_manager.observation_space + + class PrimaiteSessionOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. 
@@ -45,28 +99,29 @@ class PrimaiteSessionOptions(BaseModel): class TrainingOptions(BaseModel): """Options for training the RL agent.""" - rl_framework: str - rl_algorithm: str + rl_framework: Literal["SB3", "RLLIB"] + rl_algorithm: Literal["PPO", "A2C"] seed: Optional[int] n_learn_episodes: int n_learn_steps: int - n_eval_episodes: int - n_eval_steps: int + n_eval_episodes: int = 0 + n_eval_steps: Optional[int] = None + deterministic_eval: bool + n_agents: int + agent_references: List[str] class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE.""" + """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" def __init__(self): + """Initialise a PrimaiteSession object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" self.agents: List[AbstractAgent] = [] """List of agents.""" - # self.rl_agent: AbstractAgent - # """The agent from the list which communicates with GATE to perform reinforcement learning.""" - self.step_counter: int = 0 """Current timestep within the episode.""" @@ -94,8 +149,10 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" + # self.env: + def start_session(self) -> None: - """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" + """Commence the training session.""" n_learn_steps = self.training_options.n_learn_steps n_learn_episodes = self.training_options.n_learn_episodes n_eval_steps = self.training_options.n_eval_steps @@ -119,40 +176,47 @@ class PrimaiteSession: 4. Each agent chooses an action based on the observation. 5. Each agent converts the action to a request. 6. The simulation applies the requests. + + Warning: This method should only be used with scripted agents. 
For RL agents, the environment that the agent + interacts with should implement a step method that calls methods used by this method. For example, if using a + single-agent gym, make sure to update the ProxyAgent's action with the action before calling + ``self.apply_agent_actions()``. """ _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") - # currently designed with assumption that all agents act once per step in order + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state + self.update_agents(sim_state) + + # Apply all actions to simulation as requests + self.apply_agent_actions() + + # Advance timestep + self.advance_timestep() + + def get_sim_state(self) -> Dict: + """Get the current state of the simulation.""" + return self.simulation.describe_state() + + def update_agents(self, state: Dict) -> None: + """Update agents' observations and rewards based on the current state.""" for agent in self.agents: - # 3. primaite session asks simulation to provide initial state - # 4. primate session gives state to all agents - # 5. primaite session asks agents to produce an action based on most recent state - _LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}") - sim_state = self.simulation.describe_state() + agent.update_observation(state) + agent.update_reward(state) - # 6. each agent takes most recent state and converts it to CAOS observation - agent_obs = agent.convert_state_to_obs(sim_state) + def apply_agent_actions(self) -> None: + """Apply all actions to simulation as requests.""" + for agent in self.agents: + obs = agent.observation_manager.current_observation + rew = agent.reward_function.current_reward + action_choice, options = agent.get_action(obs, rew) + request = agent.format_request(action_choice, options) + self.simulation.apply_request(request) - # 7. 
meanwhile each agent also takes state and calculates reward - agent_reward = agent.calculate_reward_from_state(sim_state) - - # 8. each agent takes observation and applies decision rule to observation to create CAOS - # action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action - # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of - # code should live inside of the GATE agent subclass) - # gets action in CAOS format - _LOGGER.debug("Getting agent action") - agent_action, action_options = agent.get_action(agent_obs, agent_reward) - # 9. CAOS action is converted into request (extra information might be needed to enrich - # the request, this is what the execution definition is there for) - _LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements - agent_request = agent.format_request(agent_action, action_options) - - # 10. primaite session receives the action from the agents and asks the simulation to apply each - _LOGGER.debug(f"Sending request to simulation: {agent_request}") - self.simulation.apply_request(agent_request) - - _LOGGER.debug(f"Initiating simulation step {self.step_counter}") + def advance_timestep(self) -> None: + """Advance timestep.""" self.simulation.apply_timestep(self.step_counter) self.step_counter += 1 @@ -161,7 +225,7 @@ class PrimaiteSession: return NotImplemented def close(self) -> None: - """Close the session, this will stop the gate client and close the simulation.""" + """Close the session, this will stop the env and close the simulation.""" return NotImplemented @classmethod @@ -169,7 +233,7 @@ class PrimaiteSession: """Create a PrimaiteSession object from a config dictionary. The config dictionary should have the following top-level keys: - 1. training_config: options for training the RL agent. Used by GATE. + 1. training_config: options for training the RL agent. 2. game_config: options for the game itself. Used by PrimaiteSession. 3. 
simulation: defines the network topology and the initial state of the simulation. @@ -323,7 +387,7 @@ reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationSpace.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, sess) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] @@ -359,15 +423,14 @@ reward_function=rew_function, ) sess.agents.append(new_agent) - elif agent_type == "GATERLAgent": - new_agent = RandomAgent( + elif agent_type == "RLAgent": + new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=rew_function, ) sess.agents.append(new_agent) - sess.rl_agent = new_agent elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -379,4 +442,7 @@ else: print("agent type not found") + # CREATE POLICY + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + + return sess