From 6e5e1e6456a11124c147e1ad75297a3db16676ab Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 11:38:29 +0000 Subject: [PATCH 01/17] Begin rllib --- pyproject.toml | 3 ++- src/primaite/game/policy/rllib.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 src/primaite/game/policy/rllib.py diff --git a/pyproject.toml b/pyproject.toml index 1e074c25..2f8cb803 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,8 @@ dependencies = [ "tensorflow==2.12.0", "typer[all]==0.9.0", "pydantic==2.1.1", - "enlighten==1.12.2" + "enlighten==1.12.2", + "ray[rllib] == 2.8.0, < 3" ] [tool.setuptools.dynamic] diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py new file mode 100644 index 00000000..721a7500 --- /dev/null +++ b/src/primaite/game/policy/rllib.py @@ -0,0 +1,18 @@ + + +from typing import Literal, Optional, Type, TYPE_CHECKING, Union + +from primaite.game.policy import PolicyABC + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession, TrainingOptions + +from ray.rllib + + +class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): + """Single agent RL policy using Ray RLLib.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + super().__init__(session=session) + From 3fb7bce3ce547a1d139f297ef5e4a20f8d028d00 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 17:57:57 +0000 Subject: [PATCH 02/17] Get RLLib to stop crashing. --- .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/environment.py | 67 +++++++++++++++ src/primaite/game/policy/__init__.py | 3 +- src/primaite/game/policy/rllib.py | 81 ++++++++++++++++++- src/primaite/game/session.py | 62 +------------- 5 files changed, 149 insertions(+), 66 deletions(-) create mode 100644 src/primaite/game/environment.py diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index e0ff9276..c581ae49 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -1,5 +1,5 @@ training_config: - rl_framework: SB3 + rl_framework: RLLIB_single_agent rl_algorithm: PPO seed: 333 n_learn_episodes: 25 diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py new file mode 100644 index 00000000..b88a8202 --- /dev/null +++ b/src/primaite/game/environment.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, TYPE_CHECKING + +import gymnasium +from gymnasium.core import ActType, ObsType + +from primaite.game.agent.interface import ProxyAgent + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession + + +class PrimaiteGymEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. + + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. + """ + + def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + """Initialise the environment.""" + super().__init__() + self.session: "PrimaiteSession" = session + self.agent: ProxyAgent = agents[0] + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen my the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.session.apply_agent_actions() + self.session.advance_timestep() + state = self.session.get_sim_state() + self.session.update_agents(state) + + next_obs = self._get_obs() + reward = self.agent.reward_function.current_reward + terminated = False + truncated = self.session.calculate_truncated() + info = {} + + return next_obs, reward, terminated, truncated, info + + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: + """Reset the environment.""" + self.session.reset() + state = self.session.get_sim_state() + self.session.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.space + + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py index 29196112..9c0e4199 100644 --- a/src/primaite/game/policy/__init__.py +++ b/src/primaite/game/policy/__init__.py @@ -1,3 +1,4 @@ +from primaite.game.policy.rllib import RaySingleAgentPolicy from primaite.game.policy.sb3 import SB3Policy -__all__ = ["SB3Policy"] +__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index 721a7500..6e9e1096 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -1,13 +1,20 @@ +from pathlib import Path +from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union +import gymnasium +from gymnasium.core import ActType, ObsType -from typing import Literal, Optional, Type, TYPE_CHECKING, Union - -from primaite.game.policy import PolicyABC +from primaite.game.environment import PrimaiteGymEnv +from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: + from primaite.game.agent.interface import ProxyAgent from primaite.game.session import PrimaiteSession, TrainingOptions -from ray.rllib +import ray +from ray.rllib.algorithms import Algorithm, ppo +from ray.rllib.algorithms.ppo import PPOConfig +from ray.tune.registry import register_env class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): @@ -15,4 +22,70 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) + ray.init() + class RayPrimaiteGym(gymnasium.Env): + def __init__(self, env_config: Dict) -> None: + self.action_space = session.env.action_space + self.observation_space = session.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + obs, info = session.env.reset() + return obs, info + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + obs, reward, terminated, truncated, info = session.env.step(action) + return obs, reward, terminated, truncated, info + + ray.shutdown() + ray.init() + + config = { + "env": RayPrimaiteGym, + "env_config": {}, + "disable_env_checking": True, + "num_rollout_workers": 0, + } + + self._algo = ppo.PPO(config=config) + + # self._agent_config = (PPOConfig() + # .update_from_dict({ + # "num_gpus":0, + # "num_workers":0, + # "batch_mode":"complete_episodes", + # "framework":"torch", + # }) + # .environment( + # env="primaite", + # env_config={"session": session, "agents": session.rl_agents,}, + # # disable_env_checking=True + # ) + # # .rollouts(num_rollout_workers=0, + # # num_envs_per_worker=0) + # # .framework("tf2") + # .evaluation(evaluation_num_workers=0) + # ) + + # self._agent:Algorithm = self._agent_config.build(use_copy=False) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + + for ep in range(n_episodes): + res = self._algo.train() + print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}") + + def eval(self, n_episodes: int, deterministic: bool) -> None: + raise NotImplementedError + + def save(self, save_path: Path) -> None: + raise NotImplementedError + + def load(self, model_path: Path) -> None: + raise NotImplementedError + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": + """Create a policy from a config.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a2c04980..aae26fab 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple import enlighten -import gymnasium from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict @@ -14,6 +13,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.environment import PrimaiteGymEnv from primaite.game.io import SessionIO, SessionIOSettings from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -39,64 +39,6 @@ progress_bar_manager = enlighten.get_manager() _LOGGER = getLogger(__name__) -class PrimaiteGymEnv(gymnasium.Env): - """ - Thin wrapper env to provide agents with a gymnasium API. - - This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some - assumptions about the agent list always having a list of length 1. - """ - - def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): - """Initialise the environment.""" - super().__init__() - self.session: "PrimaiteSession" = session - self.agent: ProxyAgent = agents[0] - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """Perform a step in the environment.""" - # make ProxyAgent store the action chosen my the RL policy - self.agent.store_action(action) - # apply_agent_actions accesses the action we just stored - self.session.apply_agent_actions() - self.session.advance_timestep() - state = self.session.get_sim_state() - self.session.update_agents(state) - - next_obs = self._get_obs() - reward = self.agent.reward_function.current_reward - terminated = False - truncated = self.session.calculate_truncated() - info = {} - - return next_obs, reward, terminated, truncated, info - - def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: - """Reset the environment.""" - self.session.reset() - state = self.session.get_sim_state() - self.session.update_agents(state) - next_obs = self._get_obs() - info = {} - return next_obs, info - - @property - def action_space(self) -> gymnasium.Space: - """Return the action space of the environment.""" - return self.agent.action_manager.space - - @property - def observation_space(self) -> gymnasium.Space: - """Return the observation space of the environment.""" - return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) - - def _get_obs(self) -> ObsType: - """Return the current observation.""" - unflat_space = self.agent.observation_manager.space - unflat_obs = self.agent.observation_manager.current_observation - return gymnasium.spaces.flatten(unflat_space, unflat_obs) - - class PrimaiteSessionOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. @@ -115,7 +57,7 @@ class TrainingOptions(BaseModel): model_config = ConfigDict(extra="forbid") - rl_framework: Literal["SB3", "RLLIB"] + rl_framework: Literal["SB3", "RLLIB_single_agent"] rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int n_eval_episodes: Optional[int] = None From afd64e467403824cd7c2db80eec02cfd58a5fe46 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 11:59:25 +0000 Subject: [PATCH 03/17] Separate game, environment, and session --- .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/agent/actions.py | 8 +- src/primaite/game/agent/observations.py | 28 ++-- src/primaite/game/agent/rewards.py | 12 +- src/primaite/game/environment.py | 6 +- src/primaite/game/{session.py => game.py} | 152 ++++-------------- src/primaite/game/policy/policy.py | 8 +- src/primaite/game/policy/rllib.py | 21 ++- src/primaite/game/policy/sb3.py | 6 +- src/primaite/main.py | 4 +- .../notebooks/train_rllib_single_agent.ipynb | 129 +++++++++++++++ src/primaite/session/__init__.py | 0 src/primaite/{game => session}/io.py | 0 src/primaite/session/session.py | 92 +++++++++++ tests/conftest.py | 4 +- 15 files changed, 304 insertions(+), 168 deletions(-) rename src/primaite/game/{session.py => game.py} (78%) create mode 100644 src/primaite/notebooks/train_rllib_single_agent.ipynb create mode 100644 src/primaite/session/__init__.py rename src/primaite/{game => session}/io.py (100%) create mode 100644 src/primaite/session/session.py diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 3d918f2b..443b0efe 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: RLLIB_single_agent rl_algorithm: PPO seed: 333 - n_learn_episodes: 25 + n_learn_episodes: 1 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index b06013cd..c8095aa5 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractAction(ABC): @@ -559,7 +559,7 @@ class ActionManager: def __init__( self, - session: "PrimaiteSession", # reference to session for looking up stuff + session: "PrimaiteGame", # reference to session for looking up stuff actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node max_folders_per_node: int = 2, # allows calculating shape @@ -599,7 +599,7 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.session: "PrimaiteSession" = session + self.session: "PrimaiteGame" = session self.sim: Simulation = self.session.simulation self.node_uuids: List[str] = node_uuids self.protocols: List[str] = protocols @@ -826,7 +826,7 @@ class ActionManager: return nics[nic_idx] @classmethod - def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager": + def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager": """ Construct an ActionManager from a config definition. diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a74771c0..f57ec10d 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractObservation(ABC): @@ -37,7 +37,7 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config: Dict, session: "PrimaiteSession"): + def from_config(cls, config: Dict, session: "PrimaiteGame"): """Create this observation space component form a serialised format. The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, @@ -91,7 +91,7 @@ class FileObservation(AbstractObservation): return spaces.Dict({"health_status": spaces.Discrete(6)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": """Create file observation from a config. :param config: Dictionary containing the configuration for this file observation. @@ -149,7 +149,7 @@ class ServiceObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None + cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None ) -> "ServiceObservation": """Create service observation from a config. @@ -219,7 +219,7 @@ class LinkObservation(AbstractObservation): return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation": """Create link observation from a config. :param config: Dictionary containing the configuration for this link observation. @@ -310,7 +310,7 @@ class FolderObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 ) -> "FolderObservation": """Create folder observation from a config. Also creates child file observations. @@ -376,9 +376,7 @@ class NicObservation(AbstractObservation): return spaces.Dict({"nic_status": spaces.Discrete(3)}) @classmethod - def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] - ) -> "NicObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": """Create NIC observation from a config. :param config: Dictionary containing the configuration for this NIC observation. @@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation): def from_config( cls, config: Dict, - session: "PrimaiteSession", + session: "PrimaiteGame", parent_where: Optional[List[str]] = None, num_services_per_node: int = 2, num_folders_per_node: int = 2, @@ -694,7 +692,7 @@ class AclObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation": """Generate ACL observation from a config. :param config: Dictionary containing the configuration for this ACL observation. @@ -740,7 +738,7 @@ class NullObservation(AbstractObservation): return spaces.Discrete(1) @classmethod - def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation": + def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation": """ Create null observation from a config. @@ -836,7 +834,7 @@ class UC2BlueObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation": """Create UC2 blue observation from a config. :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, @@ -907,7 +905,7 @@ class UC2RedObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation": """ Create UC2 red observation from a config. @@ -966,7 +964,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index da1331b0..60c3678c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractReward: @@ -47,7 +47,7 @@ class AbstractReward: @classmethod @abstractmethod - def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward": + def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -68,7 +68,7 @@ class DummyReward(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward": + def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor. Should be empty. @@ -119,7 +119,7 @@ class DatabaseFileIntegrity(AbstractReward): return 0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -193,7 +193,7 @@ class WebServer404Penalty(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -265,7 +265,7 @@ class RewardFunction: return self.current_reward @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction": """Create a reward function from a config dictionary. :param config: dict of options for the reward manager's constructor diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py index b88a8202..36f808bb 100644 --- a/src/primaite/game/environment.py +++ b/src/primaite/game/environment.py @@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType from primaite.game.agent.interface import ProxyAgent if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class PrimaiteGymEnv(gymnasium.Env): @@ -17,10 +17,10 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + def __init__(self, session: "PrimaiteGame", agents: List[ProxyAgent]): """Initialise the environment.""" super().__init__() - self.session: "PrimaiteSession" = session + self.session: "PrimaiteGame" = session self.agent: ProxyAgent = agents[0] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: diff --git a/src/primaite/game/session.py b/src/primaite/game/game.py similarity index 78% rename from src/primaite/game/session.py rename to src/primaite/game/game.py index c4195925..7dd50924 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/game.py @@ -1,11 +1,7 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" -from enum import Enum from ipaddress import IPv4Address -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple +from typing import Dict, List -import gymnasium -from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict from primaite import getLogger @@ -13,9 +9,6 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.game.environment import PrimaiteGymEnv -from primaite.game.io import SessionIO, SessionIOSettings -from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -37,7 +30,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteSessionOptions(BaseModel): +class PrimaiteGameOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. @@ -46,37 +39,17 @@ class PrimaiteSessionOptions(BaseModel): model_config = ConfigDict(extra="forbid") + max_episode_length: int = 256 ports: List[str] protocols: List[str] -class TrainingOptions(BaseModel): - """Options for training the RL agent.""" +class PrimaiteGame: + """ + Primaite game encapsulates the simulation and agents which interact with it. - model_config = ConfigDict(extra="forbid") - - rl_framework: Literal["SB3", "RLLIB_single_agent"] - rl_algorithm: Literal["PPO", "A2C"] - n_learn_episodes: int - n_eval_episodes: Optional[int] = None - max_steps_per_episode: int - # checkpoint_freq: Optional[int] = None - deterministic_eval: bool - seed: Optional[int] - n_agents: int - agent_references: List[str] - - -class SessionMode(Enum): - """Helper to keep track of the current session mode.""" - - TRAIN = "train" - EVAL = "eval" - MANUAL = "manual" - - -class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" + Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment. + """ def __init__(self): """Initialise a PrimaiteSession object.""" @@ -95,15 +68,9 @@ class PrimaiteSession: self.episode_counter: int = 0 """Current episode number.""" - self.options: PrimaiteSessionOptions + self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" - self.training_options: TrainingOptions - """Options specific to agent training.""" - - self.policy: PolicyABC - """The reinforcement learning policy.""" - self.ref_map_nodes: Dict[str, Node] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" @@ -116,40 +83,6 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" - self.env: PrimaiteGymEnv - """The environment that the agent can consume. Could be PrimaiteEnv.""" - - self.mode: SessionMode = SessionMode.MANUAL - """Current session mode.""" - - self.io_manager = SessionIO() - """IO manager for the session.""" - - def start_session(self) -> None: - """Commence the training session.""" - self.mode = SessionMode.TRAIN - n_learn_episodes = self.training_options.n_learn_episodes - n_eval_episodes = self.training_options.n_eval_episodes - max_steps_per_episode = self.training_options.max_steps_per_episode - - deterministic_eval = self.training_options.deterministic_eval - self.policy.learn( - n_episodes=n_learn_episodes, - timesteps_per_episode=max_steps_per_episode, - ) - self.save_models() - - self.mode = SessionMode.EVAL - if n_eval_episodes > 0: - self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) - - self.mode = SessionMode.MANUAL - - def save_models(self) -> None: - """Save the RL models.""" - save_path = self.io_manager.generate_model_save_path("temp_model_name") - self.policy.save(save_path) - def step(self): """ Perform one step of the simulation/agent loop. @@ -210,7 +143,7 @@ class PrimaiteSession: def calculate_truncated(self) -> bool: """Calculate whether the episode is truncated.""" current_step = self.step_counter - max_steps = self.training_options.max_steps_per_episode + max_steps = self.options.max_episode_length if current_step >= max_steps: return True return False @@ -227,8 +160,8 @@ class PrimaiteSession: return NotImplemented @classmethod - def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": - """Create a PrimaiteSession object from a config dictionary. + def from_config(cls, cfg: Dict) -> "PrimaiteGame": + """Create a PrimaiteGame object from a config dictionary. The config dictionary should have the following top-level keys: 1. training_config: options for training the RL agent. @@ -243,23 +176,16 @@ class PrimaiteSession: :return: A PrimaiteSession object. :rtype: PrimaiteSession """ - sess = cls() - sess.options = PrimaiteSessionOptions( - ports=cfg["game_config"]["ports"], - protocols=cfg["game_config"]["protocols"], - ) - sess.training_options = TrainingOptions(**cfg["training_config"]) + game = cls() + game.options = PrimaiteGameOptions(cfg["game"]) - # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - - sim = sess.simulation + # 1. create simulation + sim = game.simulation net = sim.network - sess.ref_map_nodes: Dict[str, Node] = {} - sess.ref_map_services: Dict[str, Service] = {} - sess.ref_map_links: Dict[str, Link] = {} + game.ref_map_nodes: Dict[str, Node] = {} + game.ref_map_services: Dict[str, Service] = {} + game.ref_map_links: Dict[str, Link] = {} nodes_cfg = cfg["simulation"]["network"]["nodes"] links_cfg = cfg["simulation"]["network"]["links"] @@ -323,7 +249,7 @@ class PrimaiteSession: print(f"installing {service_type} on node {new_node.hostname}") new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] - sess.ref_map_services[service_ref] = new_service + game.ref_map_services[service_ref] = new_service else: print(f"service type not found {service_type}") # service-dependent options @@ -348,7 +274,7 @@ class PrimaiteSession: if application_type in application_types_mapping: new_node.software_manager.install(application_types_mapping[application_type]) new_application = new_node.software_manager.software[application_type] - sess.ref_map_applications[application_ref] = new_application + game.ref_map_applications[application_ref] = new_application else: print(f"application type not found {application_type}") if "nics" in node_cfg: @@ -357,7 +283,7 @@ class PrimaiteSession: net.add_node(new_node) new_node.power_on() - sess.ref_map_nodes[ + game.ref_map_nodes[ node_ref ] = ( new_node.uuid @@ -365,8 +291,8 @@ class PrimaiteSession: # 2. create links between nodes for link_cfg in links_cfg: - node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]] - node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]] + node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]] + node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]] if isinstance(node_a, Switch): endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]] else: @@ -376,7 +302,7 @@ class PrimaiteSession: else: endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]] new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b) - sess.ref_map_links[link_cfg["ref"]] = new_link.uuid + game.ref_map_links[link_cfg["ref"]] = new_link.uuid # 3. create agents game_cfg = cfg["game_config"] @@ -390,14 +316,14 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, game) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] # if a list of nodes is defined, convert them from node references to node UUIDs for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}): if "node_ref" in action_node_option: - node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]] + node_uuid = game.ref_map_nodes[action_node_option["node_ref"]] action_space_cfg["options"]["node_uuids"].append(node_uuid) # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. @@ -409,12 +335,12 @@ class PrimaiteSession: if "options" in action_config: if "target_router_ref" in action_config["options"]: _target = action_config["options"]["target_router_ref"] - action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target] + action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target] - action_space = ActionManager.from_config(sess, action_space_cfg) + action_space = ActionManager.from_config(game, action_space_cfg) # CREATE REWARD FUNCTION - rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) + rew_function = RewardFunction.from_config(reward_function_cfg, session=game) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": @@ -425,7 +351,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) + game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -433,8 +359,8 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) - sess.rl_agents.append(new_agent) + game.agents.append(new_agent) + game.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -442,16 +368,8 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) + game.agents.append(new_agent) else: print("agent type not found") - # CREATE ENVIRONMENT - sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) - - # CREATE POLICY - sess.policy = PolicyABC.from_config(sess.training_options, session=sess) - if agent_load_path: - sess.policy.load(Path(agent_load_path)) - - return sess + return game diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 249c3b52..10af44b1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.game.game import PrimaiteGame, TrainingOptions class PolicyABC(ABC): @@ -32,7 +32,7 @@ class PolicyABC(ABC): return @abstractmethod - def __init__(self, session: "PrimaiteSession") -> None: + def __init__(self, session: "PrimaiteGame") -> None: """ Initialize a reinforcement learning policy. @@ -41,7 +41,7 @@ class PolicyABC(ABC): :param agents: The agents to train. :type agents: List[RLAgent] """ - self.session: "PrimaiteSession" = session + self.session: "PrimaiteGame" = session """Reference to the session.""" @abstractmethod @@ -69,7 +69,7 @@ class PolicyABC(ABC): pass @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC": """ Create an RL policy from a config by calling the relevant subclass's from_config method. diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index 6e9e1096..7828ccc7 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -1,26 +1,23 @@ from pathlib import Path -from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union +from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING import gymnasium from gymnasium.core import ActType, ObsType -from primaite.game.environment import PrimaiteGymEnv from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.agent.interface import ProxyAgent - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.game.game import PrimaiteGame + from primaite.session.session import TrainingOptions import ray -from ray.rllib.algorithms import Algorithm, ppo -from ray.rllib.algorithms.ppo import PPOConfig -from ray.tune.registry import register_env +from ray.rllib.algorithms import ppo class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): """Single agent RL policy using Ray RLLib.""" - def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) ray.init() @@ -71,21 +68,23 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - for ep in range(n_episodes): res = self._algo.train() print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}") def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate the agent.""" raise NotImplementedError def save(self, save_path: Path) -> None: - raise NotImplementedError + """Save the policy to a file.""" + self._algo.save(save_path) def load(self, model_path: Path) -> None: + """Load policy parameters from a file.""" raise NotImplementedError @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy": """Create a policy from a config.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index a4870054..de14ed0c 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.game.game import PrimaiteGame, TrainingOptions class SB3Policy(PolicyABC, identifier="SB3"): """Single agent RL policy using stable baselines 3.""" - def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) @@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"): self._agent = self._agent_class.load(model_path, env=self.session.env) @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy": """Create an agent from config file.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/main.py b/src/primaite/main.py index 1699fe51..5bc76ca2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -6,7 +6,7 @@ from typing import Optional, Union from primaite import getLogger from primaite.config.load import load -from primaite.game.session import PrimaiteSession +from primaite.game.game import PrimaiteGame # from primaite.primaite_session import PrimaiteSession @@ -32,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) + sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() diff --git a/src/primaite/notebooks/train_rllib_single_agent.ipynb b/src/primaite/notebooks/train_rllib_single_agent.ipynb new file mode 100644 index 00000000..3b608a52 --- /dev/null +++ b/src/primaite/notebooks/train_rllib_single_agent.ipynb @@ -0,0 +1,129 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2023-11-18 09:06:45,876\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-18 09:06:48,446\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-18 09:06:48,692\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" + ] + } + ], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "from primaite.game.environment import PrimaiteGymEnv\n", + "import yaml" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.config.load import example_config_path" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "installing DNSServer on node domain_controller\n", + "installing DatabaseClient on node web_server\n", + "installing WebServer on node web_server\n", + "installing DatabaseService on node database_server\n", + "service type not found DatabaseBackup\n", + "installing DataManipulationBot on node client_1\n", + "installing DNSClient on node client_1\n", + "installing DNSClient on node client_2\n" + ] + } + ], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "sess = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "env = sess.env" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "env" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/session/__init__.py b/src/primaite/session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/io.py b/src/primaite/session/io.py similarity index 100% rename from src/primaite/game/io.py rename to src/primaite/session/io.py diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py new file mode 100644 index 00000000..d7bc3f99 --- /dev/null +++ b/src/primaite/session/session.py @@ -0,0 +1,92 @@ +from enum import Enum +from typing import Dict, List, Literal, Optional + +from pydantic import BaseModel, ConfigDict + +from primaite.game.environment import PrimaiteGymEnv + +# from primaite.game.game import PrimaiteGame +from primaite.game.policy.policy import PolicyABC +from primaite.session.io import SessionIO, SessionIOSettings + + +class TrainingOptions(BaseModel): + """Options for training the RL agent.""" + + model_config = ConfigDict(extra="forbid") + + rl_framework: Literal["SB3", "RLLIB_single_agent"] + rl_algorithm: Literal["PPO", "A2C"] + n_learn_episodes: int + n_eval_episodes: Optional[int] = None + max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None + deterministic_eval: bool + seed: Optional[int] + n_agents: int + agent_references: List[str] + + +class SessionMode(Enum): + """Helper to keep track of the current session mode.""" + + TRAIN = "train" + EVAL = "eval" + MANUAL = "manual" + + +class PrimaiteSession: + """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" + + def __init__(self): + """Initialise PrimaiteSession object.""" + self.training_options: TrainingOptions + """Options specific to agent training.""" + + self.mode: SessionMode = SessionMode.MANUAL + """Current session mode.""" + + self.env: PrimaiteGymEnv + """The environment that the agent can consume. Could be PrimaiteEnv.""" + + self.policy: PolicyABC + """The reinforcement learning policy.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" + + def start_session(self) -> None: + """Commence the training/eval session.""" + self.mode = SessionMode.TRAIN + n_learn_episodes = self.training_options.n_learn_episodes + n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + + deterministic_eval = self.training_options.deterministic_eval + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) + self.save_models() + + self.mode = SessionMode.EVAL + if n_eval_episodes > 0: + self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) + + self.mode = SessionMode.MANUAL + + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) + + @classmethod + def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": + """Create a PrimaiteSession object from a config dictionary.""" + sess = cls() + + sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) diff --git a/tests/conftest.py b/tests/conftest.py index 6a65b12f..24001ffc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ import pytest import yaml from primaite import getLogger -from primaite.game.session import PrimaiteSession +from primaite.game.game import PrimaiteGame # from primaite.environment.primaite_env import Primaite # from primaite.primaite_session import PrimaiteSession @@ -74,7 +74,7 @@ def file_system() -> FileSystem: # PrimAITE v2 stuff -class TempPrimaiteSession(PrimaiteSession): +class TempPrimaiteSession(PrimaiteGame): """ A temporary PrimaiteSession class. From 1138644a4b3f7ea3ceb5cc261687e7726dd770a4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 12:59:33 +0000 Subject: [PATCH 04/17] Update to make things work with new layout --- .../config/_package_data/example_config.yaml | 993 +++++++++--------- src/primaite/game/environment.py | 26 +- src/primaite/game/game.py | 5 +- .../notebooks/train_rllib_single_agent.ipynb | 54 +- 4 files changed, 514 insertions(+), 564 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 443b0efe..f167dc2f 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -15,7 +15,8 @@ io_settings: checkpoint_interval: 5 -game_config: +game: + max_episode_length: 256 ports: - ARP - DNS @@ -26,523 +27,523 @@ game_config: - TCP - UDP - agents: - - ref: client_1_green_user - team: GREEN - type: GreenWebBrowsingAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: @@ -28,24 +26,24 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - self.session.apply_agent_actions() - self.session.advance_timestep() - state = self.session.get_sim_state() - self.session.update_agents(state) + self.game.apply_agent_actions() + self.game.advance_timestep() + state = self.game.get_sim_state() + self.game.update_agents(state) next_obs = self._get_obs() reward = self.agent.reward_function.current_reward terminated = False - truncated = self.session.calculate_truncated() + truncated = self.game.calculate_truncated() info = {} return next_obs, reward, terminated, truncated, info def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" - self.session.reset() - state = self.session.get_sim_state() - self.session.update_agents(state) + self.game.reset() + state = self.game.get_sim_state() + self.game.update_agents(state) next_obs = self._get_obs() info = {} return next_obs, info diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 7dd50924..e260285f 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -177,7 +177,7 @@ class PrimaiteGame: :rtype: PrimaiteSession """ game = cls() - game.options = PrimaiteGameOptions(cfg["game"]) + game.options = PrimaiteGameOptions(**cfg["game"]) # 1. create simulation sim = game.simulation @@ -305,8 +305,7 @@ class PrimaiteGame: game.ref_map_links[link_cfg["ref"]] = new_link.uuid # 3. create agents - game_cfg = cfg["game_config"] - agents_cfg = game_cfg["agents"] + agents_cfg = cfg["agents"] for agent_cfg in agents_cfg: agent_ref = agent_cfg["ref"] # noqa: F841 diff --git a/src/primaite/notebooks/train_rllib_single_agent.ipynb b/src/primaite/notebooks/train_rllib_single_agent.ipynb index 3b608a52..709e6e6f 100644 --- a/src/primaite/notebooks/train_rllib_single_agent.ipynb +++ b/src/primaite/notebooks/train_rllib_single_agent.ipynb @@ -4,19 +4,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "2023-11-18 09:06:45,876\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-18 09:06:48,446\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2023-11-18 09:06:48,692\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" - ] - } - ], + "outputs": [], "source": [ "from primaite.game.game import PrimaiteGame\n", "from primaite.game.environment import PrimaiteGymEnv\n", @@ -56,7 +44,7 @@ "with open(example_config_path(), 'r') as f:\n", " cfg = yaml.safe_load(f)\n", "\n", - "sess = PrimaiteGame.from_config(cfg)" + "game = PrimaiteGame.from_config(cfg)" ] }, { @@ -65,44 +53,8 @@ "metadata": {}, "outputs": [], "source": [ - "sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)" + "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)" ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "env = sess.env" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "env" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From b81dd26b713f82d422b09bc666fc046626437760 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 13:12:08 +0000 Subject: [PATCH 05/17] Add Ray env class --- src/primaite/game/environment.py | 24 ++++++++-- ...agent.ipynb => training_example_sb3.ipynb} | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) rename src/primaite/notebooks/{train_rllib_single_agent.ipynb => training_example_sb3.ipynb} (68%) diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py index 57846b99..d540bd02 100644 --- a/src/primaite/game/environment.py +++ b/src/primaite/game/environment.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType @@ -15,11 +15,11 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame, agents: List[ProxyAgent]): + def __init__(self, game: PrimaiteGame): """Initialise the environment.""" super().__init__() self.game: "PrimaiteGame" = game - self.agent: ProxyAgent = agents[0] + self.agent: ProxyAgent = self.game.rl_agents[0] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -63,3 +63,21 @@ class PrimaiteGymEnv(gymnasium.Env): unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + + +class PrimaiteRayEnv(gymnasium.Env): + """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" + + def __init__(self, env_config: Dict) -> None: + """Initialise the environment.""" + self.env = PrimaiteGymEnv(game=env_config["game"]) + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + return self.env.reset(seed=seed) + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + """Perform a step in the environment.""" + return self.env.step(action) diff --git a/src/primaite/notebooks/train_rllib_single_agent.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb similarity index 68% rename from src/primaite/notebooks/train_rllib_single_agent.ipynb rename to src/primaite/notebooks/training_example_sb3.ipynb index 709e6e6f..e4033a79 100644 --- a/src/primaite/notebooks/train_rllib_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -55,6 +55,53 @@ "source": [ "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)" ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3 import PPO" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "model = PPO('MlpPolicy', gym)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.learn(total_timesteps=1000)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "model.save(\"deleteme\")" + ] } ], "metadata": { From 9070fb44d4451b36226bd48af6e10a8fe92d5dd6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 13:26:29 +0000 Subject: [PATCH 06/17] Check that ray single agent training works --- src/primaite/game/environment.py | 9 +- .../training_example_ray_single_agent.ipynb | 129 ++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 src/primaite/notebooks/training_example_ray_single_agent.ipynb diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py index d540bd02..8ddcb88a 100644 --- a/src/primaite/game/environment.py +++ b/src/primaite/game/environment.py @@ -68,8 +68,13 @@ class PrimaiteGymEnv(gymnasium.Env): class PrimaiteRayEnv(gymnasium.Env): """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" - def __init__(self, env_config: Dict) -> None: - """Initialise the environment.""" + def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict[str, PrimaiteGame] + """ self.env = PrimaiteGymEnv(game=env_config["game"]) self.action_space = self.env.action_space self.observation_space = self.env.observation_space diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb new file mode 100644 index 00000000..f47722f5 --- /dev/null +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -0,0 +1,129 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.game.environment import PrimaiteRayEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gym = PrimaiteRayEnv({\"game\":game})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "from ray.rllib.algorithms import ppo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()\n", + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env_config = {\"game\":game}\n", + "config = {\n", + " \"env\" : PrimaiteRayEnv,\n", + " \"env_config\" : env_config,\n", + " \"disable_env_checking\": True,\n", + " \"num_rollout_workers\": 0,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo = ppo.PPO(config=config)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(5):\n", + " result = algo.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo.save(\"temp/deleteme\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 1fd5298fc56cf5b6b1f0cb155d8463bc1e25f145 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 20:22:34 +0000 Subject: [PATCH 07/17] Fix multi agent system --- .../example_config_2_rl_agents.yaml | 1164 +++++++++++++++++ src/primaite/game/environment.py | 76 +- .../training_example_ray_multi_agent.ipynb | 127 ++ 3 files changed, 1366 insertions(+), 1 deletion(-) create mode 100644 src/primaite/config/_package_data/example_config_2_rl_agents.yaml create mode 100644 src/primaite/notebooks/training_example_ray_multi_agent.ipynb diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml new file mode 100644 index 00000000..9450c419 --- /dev/null +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -0,0 +1,1164 @@ +training_config: + rl_framework: RLLIB_single_agent + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 256 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" return self.env.step(action) + + +class PrimaiteRayMARLEnv(MultiAgentEnv): + """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" + + def __init__(self, env_config: Optional[Dict] = None) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict[str, PrimaiteGame] + """ + self.game: PrimaiteGame = env_config["game"] + """Reference to the primaite game""" + self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} + """List of all possible agents in the environment. This list should not change!""" + self._agent_ids = list(self.agents.keys()) + + self.terminateds = set() + self.truncateds = set() + self.observation_space = gymnasium.spaces.Dict( + {name: agent.observation_manager.space for name, agent in self.agents.items()} + ) + self.action_space = gymnasium.spaces.Dict( + {name: agent.action_manager.space for name, agent in self.agents.items()} + ) + super().__init__() + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + self.game.reset() + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + def step( + self, actions: Dict[str, ActType] + ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: + """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. + + :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. + :type actions: Dict[str, ActType] + :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent + identifier. + :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] + """ + # 1. Perform actions + for agent_name, action in actions.items(): + self.agents[agent_name].store_action(action) + self.game.apply_agent_actions() + + # 2. Advance timestep + self.game.advance_timestep() + + # 3. Get next observations + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + + # 4. Get rewards + rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + terminateds = {name: False for name, _ in self.agents.items()} + truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} + infos = {} + terminateds["__all__"] = len(self.terminateds) == len(self.agents) + truncateds["__all__"] = self.game.calculate_truncated() + return next_obs, rewards, terminateds, truncateds, infos + + def _get_obs(self) -> Dict[str, ObsType]: + """Return the current observation.""" + return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()} diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb new file mode 100644 index 00000000..9f916af9 --- /dev/null +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.game.environment import PrimaiteRayEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# gym = PrimaiteRayEnv({\"game\":game})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "from ray import air, tune\n", + "from ray.rllib.algorithms.ppo import PPOConfig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()\n", + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.environment import PrimaiteRayMARLEnv\n", + "\n", + "\n", + "env_config = {\"game\":game}\n", + "config = (\n", + " PPOConfig()\n", + " .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n", + " .rollouts(num_rollout_workers=0)\n", + " .multi_agent(\n", + " policies={agent.agent_name for agent in game.rl_agents},\n", + " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " )\n", + " .training(train_batch_size=128)\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tune.Tuner(\n", + " \"PPO\",\n", + " run_config=air.RunConfig(\n", + " stop={\"training_iteration\": 128},\n", + " checkpoint_config=air.CheckpointConfig(\n", + " checkpoint_frequency=10,\n", + " ),\n", + " ),\n", + " param_space=config\n", + ").fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 14ae8be5e2705a17e9cff45560499ae0c1fa6706 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 23 Nov 2023 00:54:19 +0000 Subject: [PATCH 08/17] Update session after it was split from game --- src/primaite/game/policy/policy.py | 10 ++- src/primaite/game/policy/rllib.py | 66 +++++-------------- src/primaite/game/policy/sb3.py | 6 +- src/primaite/main.py | 10 +-- .../training_example_ray_multi_agent.ipynb | 4 +- .../training_example_ray_single_agent.ipynb | 8 ++- .../notebooks/training_example_sb3.ipynb | 50 ++++---------- src/primaite/{game => session}/environment.py | 0 src/primaite/session/session.py | 35 ++++++++-- 9 files changed, 76 insertions(+), 113 deletions(-) rename src/primaite/{game => session}/environment.py (100%) diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 10af44b1..984466d1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame, TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions class PolicyABC(ABC): @@ -32,7 +32,7 @@ class PolicyABC(ABC): return @abstractmethod - def __init__(self, session: "PrimaiteGame") -> None: + def __init__(self, session: "PrimaiteSession") -> None: """ Initialize a reinforcement learning policy. @@ -41,7 +41,7 @@ class PolicyABC(ABC): :param agents: The agents to train. :type agents: List[RLAgent] """ - self.session: "PrimaiteGame" = session + self.session: "PrimaiteSession" = session """Reference to the session.""" @abstractmethod @@ -69,7 +69,7 @@ class PolicyABC(ABC): pass @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": """ Create an RL policy from a config by calling the relevant subclass's from_config method. @@ -80,5 +80,3 @@ class PolicyABC(ABC): PolicyType = cls._registry[config.rl_framework] return PolicyType.from_config(config=config, session=session) - - # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index 7828ccc7..f45b9fd6 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -1,14 +1,11 @@ from pathlib import Path -from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING - -import gymnasium -from gymnasium.core import ActType, ObsType +from typing import Literal, Optional, TYPE_CHECKING from primaite.game.policy.policy import PolicyABC +from primaite.session.environment import PrimaiteRayEnv if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - from primaite.session.session import TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions import ray from ray.rllib.algorithms import ppo @@ -17,64 +14,33 @@ from ray.rllib.algorithms import ppo class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): """Single agent RL policy using Ray RLLib.""" - def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) - ray.init() - - class RayPrimaiteGym(gymnasium.Env): - def __init__(self, env_config: Dict) -> None: - self.action_space = session.env.action_space - self.observation_space = session.env.observation_space - - def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: - obs, info = session.env.reset() - return obs, info - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: - obs, reward, terminated, truncated, info = session.env.step(action) - return obs, reward, terminated, truncated, info - - ray.shutdown() - ray.init() config = { - "env": RayPrimaiteGym, - "env_config": {}, + "env": PrimaiteRayEnv, + "env_config": {"game": session.game}, "disable_env_checking": True, "num_rollout_workers": 0, } + ray.shutdown() + ray.init() + self._algo = ppo.PPO(config=config) - # self._agent_config = (PPOConfig() - # .update_from_dict({ - # "num_gpus":0, - # "num_workers":0, - # "batch_mode":"complete_episodes", - # "framework":"torch", - # }) - # .environment( - # env="primaite", - # env_config={"session": session, "agents": session.rl_agents,}, - # # disable_env_checking=True - # ) - # # .rollouts(num_rollout_workers=0, - # # num_envs_per_worker=0) - # # .framework("tf2") - # .evaluation(evaluation_num_workers=0) - # ) - - # self._agent:Algorithm = self._agent_config.build(use_copy=False) - def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" for ep in range(n_episodes): - res = self._algo.train() - print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}") + self._algo.train() def eval(self, n_episodes: int, deterministic: bool) -> None: """Evaluate the agent.""" - raise NotImplementedError + for ep in range(n_episodes): + obs, info = self.session.env.reset() + for step in range(self.session.game.options.max_episode_length): + action = self._algo.compute_single_action(observation=obs, explore=False) + obs, rew, term, trunc, info = self.session.env.step(action) def save(self, save_path: Path) -> None: """Save the policy to a file.""" @@ -85,6 +51,6 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): raise NotImplementedError @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": """Create a policy from a config.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index de14ed0c..64eebfc7 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame, TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions class SB3Policy(PolicyABC, identifier="SB3"): """Single agent RL policy using stable baselines 3.""" - def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) @@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"): self._agent = self._agent_class.load(model_path, env=self.session.env) @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": """Create an agent from config file.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/main.py b/src/primaite/main.py index 5bc76ca2..b63227a7 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -5,8 +5,8 @@ from pathlib import Path from typing import Optional, Union from primaite import getLogger -from primaite.config.load import load -from primaite.game.game import PrimaiteGame +from primaite.config.load import example_config_path, load +from primaite.session.session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession @@ -32,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path) + sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() @@ -42,6 +42,6 @@ if __name__ == "__main__": args = parser.parse_args() if not args.config: - _LOGGER.error("Please provide a config file using the --config " "argument") + args.config = example_config_path() - run(session_path=args.config) + run(args.config) diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index 9f916af9..d31d53cc 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -10,7 +10,7 @@ "import yaml\n", "from primaite.config.load import example_config_path\n", "\n", - "from primaite.game.environment import PrimaiteRayEnv" + "from primaite.session.environment import PrimaiteRayEnv" ] }, { @@ -61,7 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.game.environment import PrimaiteRayMARLEnv\n", + "from primaite.session.environment import PrimaiteRayMARLEnv\n", "\n", "\n", "env_config = {\"game\":game}\n", diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index f47722f5..9b935346 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -10,7 +10,7 @@ "import yaml\n", "from primaite.config.load import example_config_path\n", "\n", - "from primaite.game.environment import PrimaiteRayEnv" + "from primaite.session.environment import PrimaiteRayEnv" ] }, { @@ -102,7 +102,11 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from primaite.config.load import example_config_path\n", + "from primaite.main import run\n", + "run(example_config_path())" + ] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e4033a79..e5085c5e 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -2,18 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from primaite.game.game import PrimaiteGame\n", - "from primaite.game.environment import PrimaiteGymEnv\n", + "from primaite.session.environment import PrimaiteGymEnv\n", "import yaml" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -22,24 +22,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "installing DNSServer on node domain_controller\n", - "installing DatabaseClient on node web_server\n", - "installing WebServer on node web_server\n", - "installing DatabaseService on node database_server\n", - "service type not found DatabaseBackup\n", - "installing DataManipulationBot on node client_1\n", - "installing DNSClient on node client_1\n", - "installing DNSClient on node client_2\n" - ] - } - ], + "outputs": [], "source": [ "with open(example_config_path(), 'r') as f:\n", " cfg = yaml.safe_load(f)\n", @@ -49,16 +34,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)" + "gym = PrimaiteGymEnv(game=game)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -67,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -76,27 +61,16 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.learn(total_timesteps=1000)\n" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/src/primaite/game/environment.py b/src/primaite/session/environment.py similarity index 100% rename from src/primaite/game/environment.py rename to src/primaite/session/environment.py diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index d7bc3f99..9f567a95 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -1,12 +1,14 @@ from enum import Enum -from typing import Dict, List, Literal, Optional +from pathlib import Path +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict -from primaite.game.environment import PrimaiteGymEnv +from primaite.game.game import PrimaiteGame # from primaite.game.game import PrimaiteGame from primaite.game.policy.policy import PolicyABC +from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings @@ -15,7 +17,7 @@ class TrainingOptions(BaseModel): model_config = ConfigDict(extra="forbid") - rl_framework: Literal["SB3", "RLLIB_single_agent"] + rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"] rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int n_eval_episodes: Optional[int] = None @@ -38,7 +40,7 @@ class SessionMode(Enum): class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" - def __init__(self): + def __init__(self, game: PrimaiteGame): """Initialise PrimaiteSession object.""" self.training_options: TrainingOptions """Options specific to agent training.""" @@ -46,8 +48,8 @@ class PrimaiteSession: self.mode: SessionMode = SessionMode.MANUAL """Current session mode.""" - self.env: PrimaiteGymEnv - """The environment that the agent can consume. Could be PrimaiteEnv.""" + self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv] + """The environment that the RL algorithm can consume.""" self.policy: PolicyABC """The reinforcement learning policy.""" @@ -55,6 +57,9 @@ class PrimaiteSession: self.io_manager = SessionIO() """IO manager for the session.""" + self.game: PrimaiteGame = game + """Primaite Game object for managing main simulation loop and agents.""" + def start_session(self) -> None: """Commence the training/eval session.""" self.mode = SessionMode.TRAIN @@ -83,10 +88,26 @@ class PrimaiteSession: @classmethod def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" - sess = cls() + game = PrimaiteGame.from_config(cfg) + + sess = cls(game=game) sess.training_options = TrainingOptions(**cfg["training_config"]) # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... io_settings = cfg.get("io_settings", {}) sess.io_manager.settings = SessionIOSettings(**io_settings) + + # CREATE ENVIRONMENT + if sess.training_options.rl_framework == "RLLIB_single_agent": + sess.env = PrimaiteRayEnv(env_config={"game": game}) + elif sess.training_options.rl_framework == "RLLIB_multi_agent": + sess.env = PrimaiteRayMARLEnv(env_config={"game": game}) + elif sess.training_options.rl_framework == "SB3": + sess.env = PrimaiteGymEnv(game=game) + + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) + + return sess From 8a2279c6cb2de3638abba30f259e5111c9d3bbee Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 23 Nov 2023 01:40:27 +0000 Subject: [PATCH 09/17] Update end to end tests after session changes --- .../training_example_ray_single_agent.ipynb | 11 - .../assets/configs/bad_primaite_session.yaml | 992 +++++++++--------- .../configs/eval_only_primaite_session.yaml | 992 +++++++++--------- .../assets/configs/test_primaite_session.yaml | 992 +++++++++--------- .../configs/train_only_primaite_session.yaml | 992 +++++++++--------- tests/conftest.py | 3 +- .../test_rllib_multi_agent_environment.py | 43 + .../test_rllib_single_agent_environment.py | 38 + .../environments/test_sb3_environment.py | 27 + .../test_primaite_session.py | 10 +- 10 files changed, 2099 insertions(+), 2001 deletions(-) create mode 100644 tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py create mode 100644 tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py create mode 100644 tests/e2e_integration_tests/environments/test_sb3_environment.py diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index 9b935346..8ee16d41 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -96,17 +96,6 @@ "source": [ "algo.save(\"temp/deleteme\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from primaite.config.load import example_config_path\n", - "from primaite.main import run\n", - "run(example_config_path())" - ] } ], "metadata": { diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 80567aea..b5e43ab3 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -7,7 +7,7 @@ training_config: -game_config: +game: ports: - ARP - DNS @@ -18,523 +18,523 @@ game_config: - TCP - UDP - agents: - - ref: client_1_green_user - team: GREEN - type: GreenWebBrowsingAgent - observation_space: - type: UC2GreenObservation - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com - options: - nodes: - - node_ref: client_2 - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 - reward_function: - reward_components: - - type: DUMMY + reward_function: + reward_components: + - type: DUMMY - agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 - - ref: client_1_data_manipulation_red_bot - team: RED - type: RedDatabaseCorruptingAgent + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: - - node_ref: client_1 + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + operating_status + health_status + folders: {} - action_space: - action_list: - - type: DONOTHING - # FileSystem: # PrimAITE v2 stuff -class TempPrimaiteSession(PrimaiteGame): +class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py new file mode 100644 index 00000000..0cf245b4 --- /dev/null +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -0,0 +1,43 @@ +import ray +import yaml +from ray import air, tune +from ray.rllib.algorithms.ppo import PPOConfig + +from primaite.config.load import example_config_path +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteRayMARLEnv + + +def test_rllib_multi_agent_compatibility(): + """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" + + with open(example_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + + ray.shutdown() + ray.init() + + env_config = {"game": game} + config = ( + PPOConfig() + .environment(env=PrimaiteRayMARLEnv, env_config={"game": game}) + .rollouts(num_rollout_workers=0) + .multi_agent( + policies={agent.agent_name for agent in game.rl_agents}, + policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id, + ) + .training(train_batch_size=128) + ) + + tune.Tuner( + "PPO", + run_config=air.RunConfig( + stop={"training_iteration": 128}, + checkpoint_config=air.CheckpointConfig( + checkpoint_frequency=10, + ), + ), + param_space=config, + ).fit() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py new file mode 100644 index 00000000..ce23501a --- /dev/null +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -0,0 +1,38 @@ +import tempfile +from pathlib import Path + +import ray +import yaml +from ray.rllib.algorithms import ppo + +from primaite.config.load import example_config_path +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteRayEnv + + +def test_rllib_single_agent_compatibility(): + """Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system.""" + with open(example_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + + ray.shutdown() + ray.init() + + env_config = {"game": game} + config = { + "env": PrimaiteRayEnv, + "env_config": env_config, + "disable_env_checking": True, + "num_rollout_workers": 0, + } + + algo = ppo.PPO(config=config) + + for i in range(5): + result = algo.train() + + save_file = Path(tempfile.gettempdir()) / "ray/" + algo.save(save_file) + assert save_file.exists() diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py new file mode 100644 index 00000000..3907ff50 --- /dev/null +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -0,0 +1,27 @@ +"""Test that we can create a primaite environment and train sb3 agent with no crash.""" +import tempfile +from pathlib import Path + +import yaml +from stable_baselines3 import PPO + +from primaite.config.load import example_config_path +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv + + +def test_sb3_compatibility(): + """Test that the Gymnasium environment can be used with an SB3 agent.""" + with open(example_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + gym = PrimaiteGymEnv(game=game) + model = PPO("MlpPolicy", gym) + + model.learn(total_timesteps=1000) + + save_path = Path(tempfile.gettempdir()) / "model.zip" + model.save(save_path) + + assert (save_path).exists() diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index b6122bad..68672b51 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -18,15 +18,15 @@ class TestPrimaiteSession: raise AssertionError assert session is not None - assert session.simulation - assert len(session.agents) == 3 - assert len(session.rl_agents) == 1 + assert session.game.simulation + assert len(session.game.agents) == 3 + assert len(session.game.rl_agents) == 1 assert session.policy assert session.env - assert session.simulation.network - assert len(session.simulation.network.nodes) == 10 + assert session.game.simulation.network + assert len(session.game.simulation.network.nodes) == 10 @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_start_session(self, temp_primaite_session): From f1f516c51a4c842d10eb82ec81fc61c38e075bdd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 23 Nov 2023 02:51:31 +0000 Subject: [PATCH 10/17] Add multi agent session test --- src/primaite/game/policy/rllib.py | 51 +- tests/assets/configs/multi_agent_session.yaml | 1166 +++++++++++++++++ .../test_primaite_session.py | 7 + 3 files changed, 1223 insertions(+), 1 deletion(-) create mode 100644 tests/assets/configs/multi_agent_session.yaml diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index f45b9fd6..fcebf40d 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -2,13 +2,15 @@ from pathlib import Path from typing import Literal, Optional, TYPE_CHECKING from primaite.game.policy.policy import PolicyABC -from primaite.session.environment import PrimaiteRayEnv +from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv if TYPE_CHECKING: from primaite.session.session import PrimaiteSession, TrainingOptions import ray +from ray import air, tune from ray.rllib.algorithms import ppo +from ray.rllib.algorithms.ppo import PPOConfig class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): @@ -54,3 +56,50 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": """Create a policy from a config.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) + + +class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"): + """Mutli agent RL policy using Ray RLLib.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None): + """Initialise multi agent policy wrapper.""" + super().__init__(session=session) + + self.config = ( + PPOConfig() + .environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game}) + .rollouts(num_rollout_workers=0) + .multi_agent( + policies={agent.agent_name for agent in session.game.rl_agents}, + policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id, + ) + .training(train_batch_size=128) + ) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + tune.Tuner( + "PPO", + run_config=air.RunConfig( + stop={"training_iteration": n_episodes * timesteps_per_episode}, + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10), + ), + param_space=self.config, + ).fit() + + def load(self, model_path: Path) -> None: + """Load policy paramters from a file.""" + return NotImplemented + + def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate trained policy.""" + return NotImplemented + + def save(self, save_path: Path) -> None: + """Save policy parameters to a file.""" + return NotImplemented + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy": + """Create policy from config.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml new file mode 100644 index 00000000..9d71e093 --- /dev/null +++ b/tests/assets/configs/multi_agent_session.yaml @@ -0,0 +1,1166 @@ +training_config: + rl_framework: RLLIB_multi_agent + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 2 + n_eval_episodes: 1 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: #not used :( + - defender1 + - defender2 + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + + +game: + max_episode_length: 128 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # Date: Thu, 23 Nov 2023 03:07:39 +0000 Subject: [PATCH 11/17] Update doc page on primaite session. --- docs/source/primaite_session.rst | 215 +++---------------------------- 1 file changed, 18 insertions(+), 197 deletions(-) diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 472a361f..a0b53c7d 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -7,207 +7,28 @@ Run a PrimAITE Session ====================== +``PrimaiteSession`` allows the user to train or evaluate an RL agent on the primaite simulation with just a config file, +no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment, +policy, simulator, agents, and IO. + +If you want finer control over the RL policy, you can interface with the :py:module::`primaite.session.environment` +module directly without running a session. + + + Run --- -A PrimAITE session can be ran either with the ``primaite session`` command from the cli +A PrimAITE session can started either with the ``primaite session`` command from the cli (See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook. -Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters. -.. note:: - 🚧 *UNDER CONSTRUCTION* 🚧 +There are two parameters that can be specified: + - ``--config``: The path to the config file to use. If not specified, the default config file is used. + - ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created. -.. - .. code-block:: bash - :caption: Unix CLI +Outputs +------- - cd ~/primaite/2.0.0 - source ./.venv/bin/activate - primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml - - .. code-block:: powershell - :caption: Powershell CLI - - cd ~\primaite\2.0.0 - .\.venv\Scripts\activate - primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml - - - .. code-block:: python - :caption: Python - - from primaite.main import run - - training_config = - lay_down_config = - run(training_config, lay_down_config) - - When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0/sessions``). - The sub-directory is formatted as such: ``~/primaite/2.0.0/sessions//_/`` - - For example, when running a session at 17:30:00 on 31st January 2023, the session will output to: - ``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``. - - ``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``. - - To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``legacy-ldc`` options. - - - - .. code-block:: bash - :caption: Unix CLI - - cd ~/primaite/2.0.0 - source ./.venv/bin/activate - primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc - - .. code-block:: powershell - :caption: Powershell CLI - - cd ~\primaite\2.0.0 - .\.venv\Scripts\activate - primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc - - - .. code-block:: python - :caption: Python - - from primaite.main import run - - training_config = - lay_down_config = - run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True) - - - - - Outputs - ------- - - PrimAITE produces four types of outputs: - - * Session Metadata - * Results - * Diagrams - * Saved agents (training checkpoints and a final trained agent) - - - **Session Metadata** - - PrimAITE creates a ``session_metadata.json`` file that contains the following metadata: - - * **uuid** - The UUID assigned to the session upon instantiation. - * **start_datetime** - The date & time the session started in iso format. - * **end_datetime** - The date & time the session ended in iso format. - * **learning** - * **total_episodes** - The total number of training episodes completed. - * **total_time_steps** - The total number of training time steps completed. - * **evaluation** - * **total_episodes** - The total number of evaluation episodes completed. - * **total_time_steps** - The total number of evaluation time steps completed. - * **env** - * **training_config** - * **All training config items** - * **lay_down_config** - * **All lay down config items** - - - **Results** - - PrimAITE automatically creates two sets of results from each learning and evaluation session: - - * Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value - * All transactions - a csv file listing the following values for every step of every episode: - - * Timestamp - * Episode number - * Step number - * Reward value - * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X - * Initial observation space (what the blue agent observed when it decided its action) - - **Diagrams** - - * For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration. - * For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory. - - **Saved agents** - - For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state. - - **Example Session Directory Structure** - - .. code-block:: text - - ~/ - └── primaite/ - └── 2.0.0/ - └── sessions/ - └── 2023-07-18/ - └── 2023-07-18_11-06-04/ - ├── evaluation/ - │ ├── all_transactions_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv - │ └── average_reward_per_episode_2023-07-18_11-06-04.png - ├── learning/ - │ ├── all_transactions_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.csv - │ ├── average_reward_per_episode_2023-07-18_11-06-04.png - │ ├── checkpoints/ - │ │ └── sb3ppo_10.zip - │ ├── SB3_PPO.zip - │ └── tensorboard_logs/ - │ ├── PPO_1/ - │ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0 - │ ├── PPO_2/ - │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1 - │ ├── PPO_3/ - │ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2 - │ ├── PPO_4/ - │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3 - │ ├── PPO_5/ - │ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4 - │ ├── PPO_6/ - │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5 - │ ├── PPO_7/ - │ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6 - │ ├── PPO_8/ - │ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7 - │ ├── PPO_9/ - │ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8 - │ └── PPO_10/ - │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9 - ├── network_2023-07-18_11-06-04.png - └── session_metadata.json - - Loading a session - ----------------- - - A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli - (See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path. - - .. tabs:: - - .. code-tab:: bash - :caption: Unix CLI - - cd ~/primaite/2.0.0 - source ./.venv/bin/activate - primaite session --load "path/to/session" - - .. code-tab:: bash - :caption: Powershell CLI - - cd ~\primaite\2.0.0 - .\.venv\Scripts\activate - primaite session --load "path\to\session" - - - .. code-tab:: python - :caption: Python - - from primaite.main import run - - run(session_path=) - - When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory +Running a session creates a session outputs directory in your user data foler. The format looks like this: +``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folders contains simulation sys logs generated by each node, +and the saved agent checkpoints, and final model. From bd109a7cfc3fdbd1a5454fa6b7c6cc714f753e33 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 09:14:55 +0000 Subject: [PATCH 12/17] Complete session->game rename refactor --- src/primaite/game/agent/actions.py | 24 +++--- src/primaite/game/agent/observations.py | 106 ++++++++++++------------ src/primaite/game/agent/rewards.py | 38 ++++----- src/primaite/game/game.py | 20 ++--- 4 files changed, 94 insertions(+), 94 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index c8095aa5..35468098 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -43,7 +43,7 @@ class AbstractAction(ABC): """Dictionary describing the number of options for each parameter of this action. The keys of this dict must align with the keyword args of the form_request method.""" self.manager: ActionManager = manager - """Reference to the ActionManager which created this action. This is used to access the session and simulation + """Reference to the ActionManager which created this action. This is used to access the game and simulation objects.""" @abstractmethod @@ -559,7 +559,7 @@ class ActionManager: def __init__( self, - session: "PrimaiteGame", # reference to session for looking up stuff + game: "PrimaiteGame", # reference to game for information lookup actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node max_folders_per_node: int = 2, # allows calculating shape @@ -574,8 +574,8 @@ class ActionManager: ) -> None: """Init method for ActionManager. - :param session: Reference to the session to which the agent belongs. - :type session: PrimaiteSession + :param game: Reference to the game to which the agent belongs. + :type game: PrimaiteGame :param actions: List of action types which should be made available to the agent. :type actions: List[str] :param node_uuids: List of node UUIDs that this agent can act on. @@ -599,8 +599,8 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.session: "PrimaiteGame" = session - self.sim: Simulation = self.session.simulation + self.game: "PrimaiteGame" = game + self.sim: Simulation = self.game.simulation self.node_uuids: List[str] = node_uuids self.protocols: List[str] = protocols self.ports: List[str] = ports @@ -826,7 +826,7 @@ class ActionManager: return nics[nic_idx] @classmethod - def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager": + def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": """ Construct an ActionManager from a config definition. @@ -845,20 +845,20 @@ class ActionManager: These options are used to calculate the shape of the action space, and to provide additional information to the ActionManager which is required to convert the agent's action choice into a CAOS request. - :param session: The Primaite Session to which the agent belongs. - :type session: PrimaiteSession + :param game: The Primaite Game to which the agent belongs. + :type game: PrimaiteGame :param cfg: The action space config. :type cfg: Dict :return: The constructed ActionManager. :rtype: ActionManager """ obj = cls( - session=session, + game=game, actions=cfg["action_list"], # node_uuids=cfg["options"]["node_uuids"], **cfg["options"], - protocols=session.options.protocols, - ports=session.options.ports, + protocols=game.options.protocols, + ports=game.options.ports, ip_address_list=None, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index f57ec10d..14fb2fa7 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -37,10 +37,10 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config: Dict, session: "PrimaiteGame"): + def from_config(cls, config: Dict, game: "PrimaiteGame"): """Create this observation space component form a serialised format. - The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, + The `game` parameter is for a the PrimaiteGame object that spawns this component. During deserialisation, a subclass of this class may need to translate from a 'reference' to a UUID. """ pass @@ -91,13 +91,13 @@ class FileObservation(AbstractObservation): return spaces.Dict({"health_status": spaces.Discrete(6)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": """Create file observation from a config. :param config: Dictionary containing the configuration for this file observation. :type config: Dict - :param session: _description_ - :type session: PrimaiteSession + :param game: _description_ + :type game: PrimaiteGame :param parent_where: _description_, defaults to None :type parent_where: _type_, optional :return: _description_ @@ -149,20 +149,20 @@ class ServiceObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None ) -> "ServiceObservation": """Create service observation from a config. :param config: Dictionary containing the configuration for this service observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. :type parent_where: Optional[List[str]], optional :return: Constructed service observation :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid]) + return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid]) class LinkObservation(AbstractObservation): @@ -219,17 +219,17 @@ class LinkObservation(AbstractObservation): return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": """Create link observation from a config. :param config: Dictionary containing the configuration for this link observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Constructed link observation :rtype: LinkObservation """ - return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]]) + return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) class FolderObservation(AbstractObservation): @@ -310,15 +310,15 @@ class FolderObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 ) -> "FolderObservation": """Create folder observation from a config. Also creates child file observations. :param config: Dictionary containing the configuration for this folder observation. Includes the name of the folder and the files inside of it. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this folder's parent node. A typical location for a node ``where`` can be: ['network','nodes',,'file_system'] @@ -332,7 +332,7 @@ class FolderObservation(AbstractObservation): where = parent_where + ["folders", config["folder_name"]] file_configs = config["files"] - files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) @@ -376,13 +376,13 @@ class NicObservation(AbstractObservation): return spaces.Dict({"nic_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": """Create NIC observation from a config. :param config: Dictionary containing the configuration for this NIC observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent node. A typical location for a node ``where`` can be: ['network','nodes',] :type parent_where: Optional[List[str]] @@ -513,7 +513,7 @@ class NodeObservation(AbstractObservation): def from_config( cls, config: Dict, - session: "PrimaiteGame", + game: "PrimaiteGame", parent_where: Optional[List[str]] = None, num_services_per_node: int = 2, num_folders_per_node: int = 2, @@ -524,8 +524,8 @@ class NodeObservation(AbstractObservation): :param config: Dictionary containing the configuration for this node observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :param parent_where: Where in the simulation state dictionary to find the information about this node's parent network. A typical location for it would be: ['network',] :type parent_where: Optional[List[str]] @@ -541,24 +541,24 @@ class NodeObservation(AbstractObservation): :return: Constructed node observation :rtype: NodeObservation """ - node_uuid = session.ref_map_nodes[config["node_ref"]] + node_uuid = game.ref_map_nodes[config["node_ref"]] if parent_where is None: where = ["network", "nodes", node_uuid] else: where = parent_where + ["nodes", node_uuid] svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs] + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] folder_configs = config.get("folders", {}) folders = [ FolderObservation.from_config( - config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder + config=c, game=game, parent_where=where, num_files_per_folder=num_files_per_folder ) for c in folder_configs ] - nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys() + nic_uuids = game.simulation.network.nodes[node_uuid].nics.keys() nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else [] - nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs] + nics = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] logon_status = config.get("logon_status", False) return cls( where=where, @@ -692,13 +692,13 @@ class AclObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": """Generate ACL observation from a config. :param config: Dictionary containing the configuration for this ACL observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Observation object :rtype: AclObservation """ @@ -707,15 +707,15 @@ class AclObservation(AbstractObservation): for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): node_ref = ip_map_config["node_ref"] nic_num = ip_map_config["nic_num"] - node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]] + node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] nic_obj = node_obj.ethernet_port[nic_num] node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - router_uuid = session.ref_map_nodes[config["router_node_ref"]] + router_uuid = game.ref_map_nodes[config["router_node_ref"]] return cls( node_ip_to_id=node_ip_to_idx, - ports=session.options.ports, - protocols=session.options.protocols, + ports=game.options.ports, + protocols=game.options.protocols, where=["network", "nodes", router_uuid, "acl", "acl"], num_rules=max_acl_rules, ) @@ -738,7 +738,7 @@ class NullObservation(AbstractObservation): return spaces.Discrete(1) @classmethod - def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation": + def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": """ Create null observation from a config. @@ -834,14 +834,14 @@ class UC2BlueObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": """Create UC2 blue observation from a config. :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, links, ACL and ICS observations. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame :return: Constructed UC2 blue observation :rtype: UC2BlueObservation """ @@ -853,7 +853,7 @@ class UC2BlueObservation(AbstractObservation): nodes = [ NodeObservation.from_config( config=n, - session=session, + game=game, num_services_per_node=num_services_per_node, num_folders_per_node=num_folders_per_node, num_files_per_folder=num_files_per_folder, @@ -863,13 +863,13 @@ class UC2BlueObservation(AbstractObservation): ] link_configs = config["links"] - links = [LinkObservation.from_config(config=link, session=session) for link in link_configs] + links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, session=session) + acl = AclObservation.from_config(config=acl_config, game=game) ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, session=session) + ics = ICSObservation.from_config(config=ics_config, game=game) new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) return new @@ -905,17 +905,17 @@ class UC2RedObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": """ Create UC2 red observation from a config. :param config: Dictionary containing the configuration for this UC2 red observation. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame """ node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs] + nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] return cls(nodes=nodes, where=["network"]) @@ -964,7 +964,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. @@ -972,14 +972,14 @@ class ObservationManager: UC2BlueObservation, UC2RedObservation, UC2GreenObservation) The other key is 'options' which are passed to the constructor of the selected observation class. :type config: Dict - :param session: Reference to the PrimaiteSession object that spawned this observation. - :type session: PrimaiteSession + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame """ if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session)) + return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) else: raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 60c3678c..8a1c2da4 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,13 +47,13 @@ class AbstractReward: @classmethod @abstractmethod - def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward": + def from_config(cls, config: dict, game: "PrimaiteGame") -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: AbstractReward """ @@ -68,13 +68,13 @@ class DummyReward(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward": + def from_config(cls, config: dict, game: "PrimaiteGame") -> "DummyReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor. Should be empty. :type config: dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame """ return cls() @@ -119,13 +119,13 @@ class DatabaseFileIntegrity(AbstractReward): return 0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "DatabaseFileIntegrity": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: DatabaseFileIntegrity """ @@ -147,7 +147,7 @@ class DatabaseFileIntegrity(AbstractReward): f"{cls.__name__} could not be initialised from config because file_name parameter was not specified" ) return DummyReward() # TODO: better error handling - node_uuid = session.ref_map_nodes[node_ref] + node_uuid = game.ref_map_nodes[node_ref] if not node_uuid: _LOGGER.error( ( @@ -193,13 +193,13 @@ class WebServer404Penalty(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "WebServer404Penalty": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward component. :rtype: WebServer404Penalty """ @@ -212,8 +212,8 @@ class WebServer404Penalty(AbstractReward): ) _LOGGER.warn(msg) return DummyReward() # TODO: should we error out with incorrect inputs? Probably! - node_uuid = session.ref_map_nodes[node_ref] - service_uuid = session.ref_map_services[service_ref].uuid + node_uuid = game.ref_map_nodes[node_ref] + service_uuid = game.ref_map_services[service_ref].uuid if not (node_uuid and service_uuid): msg = ( f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" @@ -265,13 +265,13 @@ class RewardFunction: return self.current_reward @classmethod - def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction": + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "RewardFunction": """Create a reward function from a config dictionary. :param config: dict of options for the reward manager's constructor :type config: Dict - :param session: Reference to the PrimAITE Session object - :type session: PrimaiteSession + :param game: Reference to the PrimAITE Game object + :type game: PrimaiteGame :return: The reward manager. :rtype: RewardFunction """ @@ -281,6 +281,6 @@ class RewardFunction: rew_type = rew_component_cfg["type"] weight = rew_component_cfg.get("weight", 1.0) rew_class = cls.__rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session) + rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), game=game) new.regsiter_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e260285f..fa17b94b 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,4 +1,4 @@ -"""PrimAITE session - the main entry point to training agents on PrimAITE.""" +"""PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address from typing import Dict, List @@ -52,7 +52,7 @@ class PrimaiteGame: """ def __init__(self): - """Initialise a PrimaiteSession object.""" + """Initialise a PrimaiteGame object.""" self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" @@ -101,7 +101,7 @@ class PrimaiteGame: single-agent gym, make sure to update the ProxyAgent's action with the action before calling ``self.apply_agent_actions()``. """ - _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}") + _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") # Get the current state of the simulation sim_state = self.get_sim_state() @@ -149,14 +149,14 @@ class PrimaiteGame: return False def reset(self) -> None: - """Reset the session, this will reset the simulation.""" + """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 - _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") + _LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}") self.simulation.reset_component_for_episode(self.episode_counter) def close(self) -> None: - """Close the session, this will stop the env and close the simulation.""" + """Close the game, this will close the simulation.""" return NotImplemented @classmethod @@ -165,7 +165,7 @@ class PrimaiteGame: The config dictionary should have the following top-level keys: 1. training_config: options for training the RL agent. - 2. game_config: options for the game itself. Used by PrimaiteSession. + 2. game_config: options for the game itself. Used by PrimaiteGame. 3. simulation: defines the network topology and the initial state of the simulation. The specification for each of the three major areas is described in a separate documentation page. @@ -173,8 +173,8 @@ class PrimaiteGame: :param cfg: The config dictionary. :type cfg: dict - :return: A PrimaiteSession object. - :rtype: PrimaiteSession + :return: A PrimaiteGame object. + :rtype: PrimaiteGame """ game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) @@ -339,7 +339,7 @@ class PrimaiteGame: action_space = ActionManager.from_config(game, action_space_cfg) # CREATE REWARD FUNCTION - rew_function = RewardFunction.from_config(reward_function_cfg, session=game) + rew_function = RewardFunction.from_config(reward_function_cfg, game=game) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": From 50c9ef16cbca4757a493e67bf4632fe1c984a55d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 09:18:18 +0000 Subject: [PATCH 13/17] Move policy module into session --- src/primaite/game/policy/__init__.py | 4 ---- src/primaite/session/policy/__init__.py | 4 ++++ src/primaite/{game => session}/policy/policy.py | 0 src/primaite/{game => session}/policy/rllib.py | 2 +- src/primaite/{game => session}/policy/sb3.py | 2 +- src/primaite/session/session.py | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) delete mode 100644 src/primaite/game/policy/__init__.py create mode 100644 src/primaite/session/policy/__init__.py rename src/primaite/{game => session}/policy/policy.py (100%) rename src/primaite/{game => session}/policy/rllib.py (98%) rename src/primaite/{game => session}/policy/sb3.py (98%) diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py deleted file mode 100644 index 9c0e4199..00000000 --- a/src/primaite/game/policy/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from primaite.game.policy.rllib import RaySingleAgentPolicy -from primaite.game.policy.sb3 import SB3Policy - -__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/session/policy/__init__.py b/src/primaite/session/policy/__init__.py new file mode 100644 index 00000000..811c7a54 --- /dev/null +++ b/src/primaite/session/policy/__init__.py @@ -0,0 +1,4 @@ +from primaite.session.policy.rllib import RaySingleAgentPolicy +from primaite.session.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/session/policy/policy.py similarity index 100% rename from src/primaite/game/policy/policy.py rename to src/primaite/session/policy/policy.py diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/session/policy/rllib.py similarity index 98% rename from src/primaite/game/policy/rllib.py rename to src/primaite/session/policy/rllib.py index fcebf40d..7ba3edd0 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/session/policy/rllib.py @@ -1,8 +1,8 @@ from pathlib import Path from typing import Literal, Optional, TYPE_CHECKING -from primaite.game.policy.policy import PolicyABC from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.session.policy.policy import PolicyABC if TYPE_CHECKING: from primaite.session.session import PrimaiteSession, TrainingOptions diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/session/policy/sb3.py similarity index 98% rename from src/primaite/game/policy/sb3.py rename to src/primaite/session/policy/sb3.py index 64eebfc7..051e2770 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -8,7 +8,7 @@ from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy as PPO_MLP -from primaite.game.policy.policy import PolicyABC +from primaite.session.policy.policy import PolicyABC if TYPE_CHECKING: from primaite.session.session import PrimaiteSession, TrainingOptions diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 9f567a95..80b63ba7 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -5,12 +5,12 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict from primaite.game.game import PrimaiteGame - -# from primaite.game.game import PrimaiteGame -from primaite.game.policy.policy import PolicyABC from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings +# from primaite.game.game import PrimaiteGame +from primaite.session.policy.policy import PolicyABC + class TrainingOptions(BaseModel): """Options for training the RL agent.""" From 6754dbf54166d52d248f3b3218bd46138ec56bc1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 09:28:50 +0000 Subject: [PATCH 14/17] Remove GATE and fix a few spelling mistakes. --- docs/source/primaite_session.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index a0b53c7d..f3ef0399 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -19,7 +19,7 @@ module directly without running a session. Run --- -A PrimAITE session can started either with the ``primaite session`` command from the cli +A PrimAITE session can be started either with the ``primaite session`` command from the cli (See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook. There are two parameters that can be specified: @@ -29,6 +29,6 @@ There are two parameters that can be specified: Outputs ------- -Running a session creates a session outputs directory in your user data foler. The format looks like this: -``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folders contains simulation sys logs generated by each node, -and the saved agent checkpoints, and final model. +Running a session creates a session output directory in your user data folder. The filepath looks like this: +``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node, +the saved agent checkpoints, and final model. From abba1ef86b26928cda91f147cc6cba61097c4c2f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 09:37:26 +0000 Subject: [PATCH 15/17] Remove hardcoded checkpoint frequency in rllib --- src/primaite/session/policy/rllib.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/primaite/session/policy/rllib.py b/src/primaite/session/policy/rllib.py index 7ba3edd0..be181797 100644 --- a/src/primaite/session/policy/rllib.py +++ b/src/primaite/session/policy/rllib.py @@ -78,17 +78,18 @@ class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" + checkpoint_freq = self.session.io_manager.settings.checkpoint_interval tune.Tuner( "PPO", run_config=air.RunConfig( stop={"training_iteration": n_episodes * timesteps_per_episode}, - checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10), + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq), ), param_space=self.config, ).fit() def load(self, model_path: Path) -> None: - """Load policy paramters from a file.""" + """Load policy parameters from a file.""" return NotImplemented def eval(self, n_episodes: int, deterministic: bool) -> None: From d8975078b32b72bff5b36bb2c209ff327cab1d54 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 10:50:10 +0000 Subject: [PATCH 16/17] Fix game reset test. --- tests/e2e_integration_tests/test_primaite_session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 17d8a4d1..5ca99cfc 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -79,12 +79,12 @@ class TestPrimaiteSession: def test_session_sim_reset(self, temp_primaite_session): with temp_primaite_session as session: session: TempPrimaiteSession - client_1 = session.simulation.network.get_node_by_hostname("client_1") + client_1 = session.game.simulation.network.get_node_by_hostname("client_1") client_1.software_manager.uninstall("DataManipulationBot") assert "DataManipulationBot" not in client_1.software_manager.software - session.reset() - client_1 = session.simulation.network.get_node_by_hostname("client_1") + session.game.reset() + client_1 = session.game.simulation.network.get_node_by_hostname("client_1") assert "DataManipulationBot" in client_1.software_manager.software From 64c7dd3c84394e3f5831d55d22939ddb3c9b0b71 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 12:03:46 +0000 Subject: [PATCH 17/17] Skip slow tests for now. --- .../environments/test_rllib_multi_agent_environment.py | 2 ++ .../environments/test_rllib_single_agent_environment.py | 2 ++ tests/e2e_integration_tests/test_primaite_session.py | 1 + 3 files changed, 5 insertions(+) diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 0cf245b4..3934ce5b 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -1,3 +1,4 @@ +import pytest import ray import yaml from ray import air, tune @@ -8,6 +9,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteRayMARLEnv +@pytest.mark.skip(reason="Slow, reenable later") def test_rllib_multi_agent_compatibility(): """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index ce23501a..2b12ad98 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -1,6 +1,7 @@ import tempfile from pathlib import Path +import pytest import ray import yaml from ray.rllib.algorithms import ppo @@ -10,6 +11,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteRayEnv +@pytest.mark.skip(reason="Slow, reenable later") def test_rllib_single_agent_compatibility(): """Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system.""" with open(example_config_path(), "r") as f: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 5ca99cfc..086e9af8 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -65,6 +65,7 @@ class TestPrimaiteSession: session.start_session() # TODO: include checks that the model was loaded and that the eval-only session ran + @pytest.mark.skip(reason="Slow, reenable later") @pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True) def test_multi_agent_session(self, temp_primaite_session): """Check that we can run a training session with a multi agent system."""