From afd64e467403824cd7c2db80eec02cfd58a5fe46 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 11:59:25 +0000 Subject: [PATCH] Separate game, environment, and session --- .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/agent/actions.py | 8 +- src/primaite/game/agent/observations.py | 28 ++-- src/primaite/game/agent/rewards.py | 12 +- src/primaite/game/environment.py | 6 +- src/primaite/game/{session.py => game.py} | 152 ++++-------------- src/primaite/game/policy/policy.py | 8 +- src/primaite/game/policy/rllib.py | 21 ++- src/primaite/game/policy/sb3.py | 6 +- src/primaite/main.py | 4 +- .../notebooks/train_rllib_single_agent.ipynb | 129 +++++++++++++++ src/primaite/session/__init__.py | 0 src/primaite/{game => session}/io.py | 0 src/primaite/session/session.py | 92 +++++++++++ tests/conftest.py | 4 +- 15 files changed, 304 insertions(+), 168 deletions(-) rename src/primaite/game/{session.py => game.py} (78%) create mode 100644 src/primaite/notebooks/train_rllib_single_agent.ipynb create mode 100644 src/primaite/session/__init__.py rename src/primaite/{game => session}/io.py (100%) create mode 100644 src/primaite/session/session.py diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 3d918f2b..443b0efe 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: RLLIB_single_agent rl_algorithm: PPO seed: 333 - n_learn_episodes: 25 + n_learn_episodes: 1 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index b06013cd..c8095aa5 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation _LOGGER = getLogger(__name__) if TYPE_CHECKING: - 
from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractAction(ABC): @@ -559,7 +559,7 @@ class ActionManager: def __init__( self, - session: "PrimaiteSession", # reference to session for looking up stuff + session: "PrimaiteGame", # reference to session for looking up stuff actions: List[str], # stores list of actions available to agent node_uuids: List[str], # allows mapping index to node max_folders_per_node: int = 2, # allows calculating shape @@ -599,7 +599,7 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.session: "PrimaiteSession" = session + self.session: "PrimaiteGame" = session self.sim: Simulation = self.session.simulation self.node_uuids: List[str] = node_uuids self.protocols: List[str] = protocols @@ -826,7 +826,7 @@ class ActionManager: return nics[nic_idx] @classmethod - def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager": + def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager": """ Construct an ActionManager from a config definition. diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a74771c0..f57ec10d 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractObservation(ABC): @@ -37,7 +37,7 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config: Dict, session: "PrimaiteSession"): + def from_config(cls, config: Dict, session: "PrimaiteGame"): """Create this observation space component form a serialised format. 
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation, @@ -91,7 +91,7 @@ class FileObservation(AbstractObservation): return spaces.Dict({"health_status": spaces.Discrete(6)}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": """Create file observation from a config. :param config: Dictionary containing the configuration for this file observation. @@ -149,7 +149,7 @@ class ServiceObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None + cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None ) -> "ServiceObservation": """Create service observation from a config. @@ -219,7 +219,7 @@ class LinkObservation(AbstractObservation): return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation": """Create link observation from a config. :param config: Dictionary containing the configuration for this link observation. @@ -310,7 +310,7 @@ class FolderObservation(AbstractObservation): @classmethod def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 ) -> "FolderObservation": """Create folder observation from a config. Also creates child file observations. 
@@ -376,9 +376,7 @@ class NicObservation(AbstractObservation): return spaces.Dict({"nic_status": spaces.Discrete(3)}) @classmethod - def from_config( - cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] - ) -> "NicObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": """Create NIC observation from a config. :param config: Dictionary containing the configuration for this NIC observation. @@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation): def from_config( cls, config: Dict, - session: "PrimaiteSession", + session: "PrimaiteGame", parent_where: Optional[List[str]] = None, num_services_per_node: int = 2, num_folders_per_node: int = 2, @@ -694,7 +692,7 @@ class AclObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation": """Generate ACL observation from a config. :param config: Dictionary containing the configuration for this ACL observation. @@ -740,7 +738,7 @@ class NullObservation(AbstractObservation): return spaces.Discrete(1) @classmethod - def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation": + def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation": """ Create null observation from a config. @@ -836,7 +834,7 @@ class UC2BlueObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation": """Create UC2 blue observation from a config. :param config: Dictionary containing the configuration for this UC2 blue observation. 
This includes the nodes, @@ -907,7 +905,7 @@ class UC2RedObservation(AbstractObservation): ) @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation": """ Create UC2 red observation from a config. @@ -966,7 +964,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager": """Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index da1331b0..60c3678c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class AbstractReward: @@ -47,7 +47,7 @@ class AbstractReward: @classmethod @abstractmethod - def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward": + def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -68,7 +68,7 @@ class DummyReward(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward": + def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor. Should be empty. 
@@ -119,7 +119,7 @@ class DatabaseFileIntegrity(AbstractReward): return 0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -193,7 +193,7 @@ class WebServer404Penalty(AbstractReward): return 0.0 @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -265,7 +265,7 @@ class RewardFunction: return self.current_reward @classmethod - def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": + def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction": """Create a reward function from a config dictionary. :param config: dict of options for the reward manager's constructor diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py index b88a8202..36f808bb 100644 --- a/src/primaite/game/environment.py +++ b/src/primaite/game/environment.py @@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType from primaite.game.agent.interface import ProxyAgent if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession + from primaite.game.game import PrimaiteGame class PrimaiteGymEnv(gymnasium.Env): @@ -17,10 +17,10 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. 
""" - def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + def __init__(self, session: "PrimaiteGame", agents: List[ProxyAgent]): """Initialise the environment.""" super().__init__() - self.session: "PrimaiteSession" = session + self.session: "PrimaiteGame" = session self.agent: ProxyAgent = agents[0] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: diff --git a/src/primaite/game/session.py b/src/primaite/game/game.py similarity index 78% rename from src/primaite/game/session.py rename to src/primaite/game/game.py index c4195925..7dd50924 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/game.py @@ -1,11 +1,7 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" -from enum import Enum from ipaddress import IPv4Address -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple +from typing import Dict, List -import gymnasium -from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict from primaite import getLogger @@ -13,9 +9,6 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.game.environment import PrimaiteGymEnv -from primaite.game.io import SessionIO, SessionIOSettings -from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -37,7 +30,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteSessionOptions(BaseModel): +class PrimaiteGameOptions(BaseModel): """ Global options which 
are applicable to all of the agents in the game. @@ -46,37 +39,17 @@ class PrimaiteSessionOptions(BaseModel): model_config = ConfigDict(extra="forbid") + max_episode_length: int = 256 ports: List[str] protocols: List[str] -class TrainingOptions(BaseModel): - """Options for training the RL agent.""" +class PrimaiteGame: + """ + Primaite game encapsulates the simulation and agents which interact with it. - model_config = ConfigDict(extra="forbid") - - rl_framework: Literal["SB3", "RLLIB_single_agent"] - rl_algorithm: Literal["PPO", "A2C"] - n_learn_episodes: int - n_eval_episodes: Optional[int] = None - max_steps_per_episode: int - # checkpoint_freq: Optional[int] = None - deterministic_eval: bool - seed: Optional[int] - n_agents: int - agent_references: List[str] - - -class SessionMode(Enum): - """Helper to keep track of the current session mode.""" - - TRAIN = "train" - EVAL = "eval" - MANUAL = "manual" - - -class PrimaiteSession: - """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments.""" + Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment. + """ def __init__(self): """Initialise a PrimaiteSession object.""" @@ -95,15 +68,9 @@ class PrimaiteSession: self.episode_counter: int = 0 """Current episode number.""" - self.options: PrimaiteSessionOptions + self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" - self.training_options: TrainingOptions - """Options specific to agent training.""" - - self.policy: PolicyABC - """The reinforcement learning policy.""" - self.ref_map_nodes: Dict[str, Node] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" @@ -116,40 +83,6 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. 
Used when parsing config files.""" - self.env: PrimaiteGymEnv - """The environment that the agent can consume. Could be PrimaiteEnv.""" - - self.mode: SessionMode = SessionMode.MANUAL - """Current session mode.""" - - self.io_manager = SessionIO() - """IO manager for the session.""" - - def start_session(self) -> None: - """Commence the training session.""" - self.mode = SessionMode.TRAIN - n_learn_episodes = self.training_options.n_learn_episodes - n_eval_episodes = self.training_options.n_eval_episodes - max_steps_per_episode = self.training_options.max_steps_per_episode - - deterministic_eval = self.training_options.deterministic_eval - self.policy.learn( - n_episodes=n_learn_episodes, - timesteps_per_episode=max_steps_per_episode, - ) - self.save_models() - - self.mode = SessionMode.EVAL - if n_eval_episodes > 0: - self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) - - self.mode = SessionMode.MANUAL - - def save_models(self) -> None: - """Save the RL models.""" - save_path = self.io_manager.generate_model_save_path("temp_model_name") - self.policy.save(save_path) - def step(self): """ Perform one step of the simulation/agent loop. @@ -210,7 +143,7 @@ class PrimaiteSession: def calculate_truncated(self) -> bool: """Calculate whether the episode is truncated.""" current_step = self.step_counter - max_steps = self.training_options.max_steps_per_episode + max_steps = self.options.max_episode_length if current_step >= max_steps: return True return False @@ -227,8 +160,8 @@ class PrimaiteSession: return NotImplemented @classmethod - def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": - """Create a PrimaiteSession object from a config dictionary. + def from_config(cls, cfg: Dict) -> "PrimaiteGame": + """Create a PrimaiteGame object from a config dictionary. The config dictionary should have the following top-level keys: 1. training_config: options for training the RL agent. 
@@ -243,23 +176,16 @@ class PrimaiteSession: :return: A PrimaiteSession object. :rtype: PrimaiteSession """ - sess = cls() - sess.options = PrimaiteSessionOptions( - ports=cfg["game_config"]["ports"], - protocols=cfg["game_config"]["protocols"], - ) - sess.training_options = TrainingOptions(**cfg["training_config"]) + game = cls() + game.options = PrimaiteGameOptions(cfg["game"]) - # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - - sim = sess.simulation + # 1. create simulation + sim = game.simulation net = sim.network - sess.ref_map_nodes: Dict[str, Node] = {} - sess.ref_map_services: Dict[str, Service] = {} - sess.ref_map_links: Dict[str, Link] = {} + game.ref_map_nodes: Dict[str, Node] = {} + game.ref_map_services: Dict[str, Service] = {} + game.ref_map_links: Dict[str, Link] = {} nodes_cfg = cfg["simulation"]["network"]["nodes"] links_cfg = cfg["simulation"]["network"]["links"] @@ -323,7 +249,7 @@ class PrimaiteSession: print(f"installing {service_type} on node {new_node.hostname}") new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] - sess.ref_map_services[service_ref] = new_service + game.ref_map_services[service_ref] = new_service else: print(f"service type not found {service_type}") # service-dependent options @@ -348,7 +274,7 @@ class PrimaiteSession: if application_type in application_types_mapping: new_node.software_manager.install(application_types_mapping[application_type]) new_application = new_node.software_manager.software[application_type] - sess.ref_map_applications[application_ref] = new_application + game.ref_map_applications[application_ref] = new_application else: print(f"application type not found {application_type}") if "nics" in node_cfg: @@ -357,7 +283,7 @@ class PrimaiteSession: net.add_node(new_node) 
new_node.power_on() - sess.ref_map_nodes[ + game.ref_map_nodes[ node_ref ] = ( new_node.uuid @@ -365,8 +291,8 @@ class PrimaiteSession: # 2. create links between nodes for link_cfg in links_cfg: - node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]] - node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]] + node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]] + node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]] if isinstance(node_a, Switch): endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]] else: @@ -376,7 +302,7 @@ class PrimaiteSession: else: endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]] new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b) - sess.ref_map_links[link_cfg["ref"]] = new_link.uuid + game.ref_map_links[link_cfg["ref"]] = new_link.uuid # 3. create agents game_cfg = cfg["game_config"] @@ -390,14 +316,14 @@ class PrimaiteSession: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(observation_space_cfg, sess) + obs_space = ObservationManager.from_config(observation_space_cfg, game) # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] # if a list of nodes is defined, convert them from node references to node UUIDs for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}): if "node_ref" in action_node_option: - node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]] + node_uuid = game.ref_map_nodes[action_node_option["node_ref"]] action_space_cfg["options"]["node_uuids"].append(node_uuid) # Each action space can potentially have a different list of nodes that it can apply to. Therefore, # we will pass node_uuids as a part of the action space config. 
@@ -409,12 +335,12 @@ class PrimaiteSession: if "options" in action_config: if "target_router_ref" in action_config["options"]: _target = action_config["options"]["target_router_ref"] - action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target] + action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target] - action_space = ActionManager.from_config(sess, action_space_cfg) + action_space = ActionManager.from_config(game, action_space_cfg) # CREATE REWARD FUNCTION - rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) + rew_function = RewardFunction.from_config(reward_function_cfg, session=game) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": @@ -425,7 +351,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) + game.agents.append(new_agent) elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], @@ -433,8 +359,8 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) - sess.rl_agents.append(new_agent) + game.agents.append(new_agent) + game.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -442,16 +368,8 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, ) - sess.agents.append(new_agent) + game.agents.append(new_agent) else: print("agent type not found") - # CREATE ENVIRONMENT - sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) - - # CREATE POLICY - sess.policy = PolicyABC.from_config(sess.training_options, session=sess) - if agent_load_path: - sess.policy.load(Path(agent_load_path)) - - return sess + return game diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 249c3b52..10af44b1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -4,7 
+4,7 @@ from pathlib import Path from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.game.game import PrimaiteGame, TrainingOptions class PolicyABC(ABC): @@ -32,7 +32,7 @@ class PolicyABC(ABC): return @abstractmethod - def __init__(self, session: "PrimaiteSession") -> None: + def __init__(self, session: "PrimaiteGame") -> None: """ Initialize a reinforcement learning policy. @@ -41,7 +41,7 @@ class PolicyABC(ABC): :param agents: The agents to train. :type agents: List[RLAgent] """ - self.session: "PrimaiteSession" = session + self.session: "PrimaiteGame" = session """Reference to the session.""" @abstractmethod @@ -69,7 +69,7 @@ class PolicyABC(ABC): pass @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC": """ Create an RL policy from a config by calling the relevant subclass's from_config method. 
diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index 6e9e1096..7828ccc7 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -1,26 +1,23 @@ from pathlib import Path -from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union +from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING import gymnasium from gymnasium.core import ActType, ObsType -from primaite.game.environment import PrimaiteGymEnv from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.agent.interface import ProxyAgent - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.game.game import PrimaiteGame + from primaite.session.session import TrainingOptions import ray -from ray.rllib.algorithms import Algorithm, ppo -from ray.rllib.algorithms.ppo import PPOConfig -from ray.tune.registry import register_env +from ray.rllib.algorithms import ppo class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): """Single agent RL policy using Ray RLLib.""" - def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) ray.init() @@ -71,21 +68,23 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - for ep in range(n_episodes): res = self._algo.train() print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}") def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate the agent.""" raise NotImplementedError def save(self, save_path: Path) -> None: - raise NotImplementedError + """Save the policy to a file.""" + self._algo.save(save_path) def load(self, model_path: Path) -> 
None: + """Load policy parameters from a file.""" raise NotImplementedError @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy": """Create a policy from a config.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index a4870054..de14ed0c 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.session import PrimaiteSession, TrainingOptions + from primaite.game.game import PrimaiteGame, TrainingOptions class SB3Policy(PolicyABC, identifier="SB3"): """Single agent RL policy using stable baselines 3.""" - def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) @@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"): self._agent = self._agent_class.load(model_path, env=self.session.env) @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy": """Create an agent from config file.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/main.py b/src/primaite/main.py index 1699fe51..5bc76ca2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -6,7 +6,7 @@ from typing import Optional, Union from primaite import getLogger from primaite.config.load import load -from 
primaite.game.session import PrimaiteSession +from primaite.game.game import PrimaiteGame # from primaite.primaite_session import PrimaiteSession @@ -32,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) + sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() diff --git a/src/primaite/notebooks/train_rllib_single_agent.ipynb b/src/primaite/notebooks/train_rllib_single_agent.ipynb new file mode 100644 index 00000000..3b608a52 --- /dev/null +++ b/src/primaite/notebooks/train_rllib_single_agent.ipynb @@ -0,0 +1,129 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2023-11-18 09:06:45,876\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-18 09:06:48,446\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2023-11-18 09:06:48,692\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! 
See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n" + ] + } + ], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "from primaite.game.environment import PrimaiteGymEnv\n", + "import yaml" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.config.load import example_config_path" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "installing DNSServer on node domain_controller\n", + "installing DatabaseClient on node web_server\n", + "installing WebServer on node web_server\n", + "installing DatabaseService on node database_server\n", + "service type not found DatabaseBackup\n", + "installing DataManipulationBot on node client_1\n", + "installing DNSClient on node client_1\n", + "installing DNSClient on node client_2\n" + ] + } + ], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "sess = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "env = sess.env" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "env" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + 
"name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/session/__init__.py b/src/primaite/session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/io.py b/src/primaite/session/io.py similarity index 100% rename from src/primaite/game/io.py rename to src/primaite/session/io.py diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py new file mode 100644 index 00000000..d7bc3f99 --- /dev/null +++ b/src/primaite/session/session.py @@ -0,0 +1,92 @@ +from enum import Enum +from typing import Dict, List, Literal, Optional + +from pydantic import BaseModel, ConfigDict + +from primaite.game.environment import PrimaiteGymEnv + +# from primaite.game.game import PrimaiteGame +from primaite.game.policy.policy import PolicyABC +from primaite.session.io import SessionIO, SessionIOSettings + + +class TrainingOptions(BaseModel): + """Options for training the RL agent.""" + + model_config = ConfigDict(extra="forbid") + + rl_framework: Literal["SB3", "RLLIB_single_agent"] + rl_algorithm: Literal["PPO", "A2C"] + n_learn_episodes: int + n_eval_episodes: Optional[int] = None + max_steps_per_episode: int + # checkpoint_freq: Optional[int] = None + deterministic_eval: bool + seed: Optional[int] + n_agents: int + agent_references: List[str] + + +class SessionMode(Enum): + """Helper to keep track of the current session mode.""" + + TRAIN = "train" + EVAL = "eval" + MANUAL = "manual" + + +class PrimaiteSession: + """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" + + def __init__(self): + """Initialise PrimaiteSession object.""" + self.training_options: TrainingOptions + """Options specific to agent training.""" + + self.mode: SessionMode = 
SessionMode.MANUAL + """Current session mode.""" + + self.env: PrimaiteGymEnv + """The environment that the agent can consume. Could be PrimaiteEnv.""" + + self.policy: PolicyABC + """The reinforcement learning policy.""" + + self.io_manager = SessionIO() + """IO manager for the session.""" + + def start_session(self) -> None: + """Train the policy, save the model, then evaluate if any eval episodes are configured.""" + self.mode = SessionMode.TRAIN + n_learn_episodes = self.training_options.n_learn_episodes + n_eval_episodes = self.training_options.n_eval_episodes + max_steps_per_episode = self.training_options.max_steps_per_episode + + deterministic_eval = self.training_options.deterministic_eval + self.policy.learn( + n_episodes=n_learn_episodes, + timesteps_per_episode=max_steps_per_episode, + ) + self.save_models() + + self.mode = SessionMode.EVAL + if n_eval_episodes is not None and n_eval_episodes > 0:  # Optional option defaults to None; `None > 0` raises TypeError + self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval) + + self.mode = SessionMode.MANUAL + + def save_models(self) -> None: + """Save the RL models.""" + save_path = self.io_manager.generate_model_save_path("temp_model_name") + self.policy.save(save_path) + + @classmethod + def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": + """Create a PrimaiteSession object from a config dictionary.""" + sess = cls() + + sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... 
+ io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) + + # NOTE(review): sess.env and sess.policy are only annotated in __init__, never assigned; set them before start_session(). + return sess diff --git a/tests/conftest.py b/tests/conftest.py index 6a65b12f..24001ffc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ import pytest import yaml from primaite import getLogger -from primaite.game.session import PrimaiteSession +from primaite.game.game import PrimaiteGame # from primaite.environment.primaite_env import Primaite # from primaite.primaite_session import PrimaiteSession @@ -74,7 +74,7 @@ def file_system() -> FileSystem: # PrimAITE v2 stuff -class TempPrimaiteSession(PrimaiteSession): +class TempPrimaiteSession(PrimaiteGame): """ A temporary PrimaiteSession class.