Separate game, environment, and session

This commit is contained in:
Marek Wolan
2023-11-22 11:59:25 +00:00
parent 66d5612e92
commit afd64e4674
15 changed files with 304 additions and 168 deletions

View File

@@ -2,7 +2,7 @@ training_config:
rl_framework: RLLIB_single_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 25
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false

View File

@@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractAction(ABC):
@@ -559,7 +559,7 @@ class ActionManager:
def __init__(
self,
session: "PrimaiteSession", # reference to session for looking up stuff
session: "PrimaiteGame", # reference to session for looking up stuff
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
max_folders_per_node: int = 2, # allows calculating shape
@@ -599,7 +599,7 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.session: "PrimaiteSession" = session
self.session: "PrimaiteGame" = session
self.sim: Simulation = self.session.simulation
self.node_uuids: List[str] = node_uuids
self.protocols: List[str] = protocols
@@ -826,7 +826,7 @@ class ActionManager:
return nics[nic_idx]
@classmethod
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config definition.

View File

@@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractObservation(ABC):
@@ -37,7 +37,7 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
def from_config(cls, config: Dict, session: "PrimaiteGame"):
"""Create this observation space component from a serialised format.
The `session` parameter is for the PrimaiteGame object that spawns this component. During deserialisation,
@@ -91,7 +91,7 @@ class FileObservation(AbstractObservation):
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
@@ -149,7 +149,7 @@ class ServiceObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
@@ -219,7 +219,7 @@ class LinkObservation(AbstractObservation):
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
@@ -310,7 +310,7 @@ class FolderObservation(AbstractObservation):
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
@@ -376,9 +376,7 @@ class NicObservation(AbstractObservation):
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
) -> "NicObservation":
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
@@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation):
def from_config(
cls,
config: Dict,
session: "PrimaiteSession",
session: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
@@ -694,7 +692,7 @@ class AclObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
@@ -740,7 +738,7 @@ class NullObservation(AbstractObservation):
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
@@ -836,7 +834,7 @@ class UC2BlueObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
@@ -907,7 +905,7 @@ class UC2RedObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
@@ -966,7 +964,7 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.

View File

@@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class AbstractReward:
@@ -47,7 +47,7 @@ class AbstractReward:
@classmethod
@abstractmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
@@ -68,7 +68,7 @@ class DummyReward(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
@@ -119,7 +119,7 @@ class DatabaseFileIntegrity(AbstractReward):
return 0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
@@ -193,7 +193,7 @@ class WebServer404Penalty(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
@@ -265,7 +265,7 @@ class RewardFunction:
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction":
"""Create a reward function from a config dictionary.
:param config: dict of options for the reward manager's constructor

View File

@@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType
from primaite.game.agent.interface import ProxyAgent
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
class PrimaiteGymEnv(gymnasium.Env):
@@ -17,10 +17,10 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions that the agent list always has a length of 1.
"""
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
def __init__(self, session: "PrimaiteGame", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.session: "PrimaiteGame" = session
self.agent: ProxyAgent = agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:

View File

@@ -1,11 +1,7 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
from typing import Dict, List
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
@@ -13,9 +9,6 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.environment import PrimaiteGymEnv
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -37,7 +30,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteSessionOptions(BaseModel):
class PrimaiteGameOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
@@ -46,37 +39,17 @@ class PrimaiteSessionOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
max_episode_length: int = 256
ports: List[str]
protocols: List[str]
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
class PrimaiteGame:
"""
Primaite game encapsulates the simulation and agents which interact with it.
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB_single_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
Provides the main logic loop for the game. However, it does not provide policy training or a gymnasium environment.
"""
def __init__(self):
"""Initialise a PrimaiteSession object."""
@@ -95,15 +68,9 @@ class PrimaiteSession:
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
@@ -116,40 +83,6 @@ class PrimaiteSession:
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -210,7 +143,7 @@ class PrimaiteSession:
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
max_steps = self.options.max_episode_length
if current_step >= max_steps:
return True
return False
@@ -227,8 +160,8 @@ class PrimaiteSession:
return NotImplemented
@classmethod
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
@@ -243,23 +176,16 @@ class PrimaiteSession:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
sess = cls()
sess.options = PrimaiteSessionOptions(
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
game = cls()
game.options = PrimaiteGameOptions(cfg["game"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
# 1. create simulation
sim = game.simulation
net = sim.network
sess.ref_map_nodes: Dict[str, Node] = {}
sess.ref_map_services: Dict[str, Service] = {}
sess.ref_map_links: Dict[str, Link] = {}
game.ref_map_nodes: Dict[str, Node] = {}
game.ref_map_services: Dict[str, Service] = {}
game.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
@@ -323,7 +249,7 @@ class PrimaiteSession:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
sess.ref_map_services[service_ref] = new_service
game.ref_map_services[service_ref] = new_service
else:
print(f"service type not found {service_type}")
# service-dependent options
@@ -348,7 +274,7 @@ class PrimaiteSession:
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
sess.ref_map_applications[application_ref] = new_application
game.ref_map_applications[application_ref] = new_application
else:
print(f"application type not found {application_type}")
if "nics" in node_cfg:
@@ -357,7 +283,7 @@ class PrimaiteSession:
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[
game.ref_map_nodes[
node_ref
] = (
new_node.uuid
@@ -365,8 +291,8 @@ class PrimaiteSession:
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
else:
@@ -376,7 +302,7 @@ class PrimaiteSession:
else:
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
# 3. create agents
game_cfg = cfg["game_config"]
@@ -390,14 +316,14 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, game)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
@@ -409,12 +335,12 @@ class PrimaiteSession:
if "options" in action_config:
if "target_router_ref" in action_config["options"]:
_target = action_config["options"]["target_router_ref"]
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
action_space = ActionManager.from_config(game, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
rew_function = RewardFunction.from_config(reward_function_cfg, session=game)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
@@ -425,7 +351,7 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -433,8 +359,8 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agents.append(new_agent)
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
@@ -442,16 +368,8 @@ class PrimaiteSession:
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
game.agents.append(new_agent)
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess
return game

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.game.game import PrimaiteGame, TrainingOptions
class PolicyABC(ABC):
@@ -32,7 +32,7 @@ class PolicyABC(ABC):
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
def __init__(self, session: "PrimaiteGame") -> None:
"""
Initialize a reinforcement learning policy.
@@ -41,7 +41,7 @@ class PolicyABC(ABC):
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
self.session: "PrimaiteGame" = session
"""Reference to the session."""
@abstractmethod
@@ -69,7 +69,7 @@ class PolicyABC(ABC):
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.

View File

@@ -1,26 +1,23 @@
from pathlib import Path
from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union
from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING
import gymnasium
from gymnasium.core import ActType, ObsType
from primaite.game.environment import PrimaiteGymEnv
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.agent.interface import ProxyAgent
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.game.game import PrimaiteGame
from primaite.session.session import TrainingOptions
import ray
from ray.rllib.algorithms import Algorithm, ppo
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from ray.rllib.algorithms import ppo
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
ray.init()
@@ -71,21 +68,23 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
for ep in range(n_episodes):
res = self._algo.train()
print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}")
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
raise NotImplementedError
def save(self, save_path: Path) -> None:
raise NotImplementedError
"""Save the policy to a file."""
self._algo.save(save_path)
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
from primaite.game.game import PrimaiteGame, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
@@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"):
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -6,7 +6,7 @@ from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import load
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
# from primaite.primaite_session import PrimaiteSession
@@ -32,7 +32,7 @@ def run(
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()

View File

@@ -0,0 +1,129 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2023-11-18 09:06:45,876\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-18 09:06:48,446\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-18 09:06:48,692\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n"
]
}
],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"from primaite.game.environment import PrimaiteGymEnv\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from primaite.config.load import example_config_path"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"installing DNSServer on node domain_controller\n",
"installing DatabaseClient on node web_server\n",
"installing WebServer on node web_server\n",
"installing DatabaseService on node database_server\n",
"service type not found DatabaseBackup\n",
"installing DataManipulationBot on node client_1\n",
"installing DNSClient on node client_1\n",
"installing DNSClient on node client_2\n"
]
}
],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"sess = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"env = sess.env"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<primaite.game.environment.PrimaiteGymEnv at 0x7fad7190d7b0>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

View File

@@ -0,0 +1,92 @@
from enum import Enum
from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, ConfigDict
from primaite.game.environment import PrimaiteGymEnv
# from primaite.game.game import PrimaiteGame
from primaite.game.policy.policy import PolicyABC
from primaite.session.io import SessionIO, SessionIOSettings
class TrainingOptions(BaseModel):
    """Configuration options controlling how the RL agent is trained and evaluated."""

    # Reject unrecognised config keys so typos in YAML fail loudly.
    model_config = ConfigDict(extra="forbid")

    rl_framework: Literal["SB3", "RLLIB_single_agent"]  # which RL library backs the policy
    rl_algorithm: Literal["PPO", "A2C"]  # algorithm to use within that framework
    n_learn_episodes: int  # number of training episodes to run
    n_eval_episodes: Optional[int] = None  # evaluation episodes; None/0 skips evaluation
    max_steps_per_episode: int  # hard cap on steps per episode
    # checkpoint_freq: Optional[int] = None
    deterministic_eval: bool  # whether evaluation uses deterministic actions
    seed: Optional[int]  # RNG seed; None means unseeded
    n_agents: int  # number of agents participating in the game
    agent_references: List[str]  # reference names of the agents being trained
class SessionMode(Enum):
    """Tracks which phase of operation the session is currently in."""

    TRAIN = "train"  # policy is learning
    EVAL = "eval"  # policy is being evaluated
    MANUAL = "manual"  # no automated training or evaluation in progress
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions; this manages a simulation, policy training, and environments."""
def __init__(self):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training/eval session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
sess = cls()
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)

View File

@@ -11,7 +11,7 @@ import pytest
import yaml
from primaite import getLogger
from primaite.game.session import PrimaiteSession
from primaite.game.game import PrimaiteGame
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession
@@ -74,7 +74,7 @@ def file_system() -> FileSystem:
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteSession):
class TempPrimaiteSession(PrimaiteGame):
"""
A temporary PrimaiteSession class.