Separate game, environment, and session
This commit is contained in:
@@ -2,7 +2,7 @@ training_config:
|
||||
rl_framework: RLLIB_single_agent
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 25
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
|
||||
@@ -20,7 +20,7 @@ from primaite.simulator.sim_container import Simulation
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
@@ -559,7 +559,7 @@ class ActionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: "PrimaiteSession", # reference to session for looking up stuff
|
||||
session: "PrimaiteGame", # reference to session for looking up stuff
|
||||
actions: List[str], # stores list of actions available to agent
|
||||
node_uuids: List[str], # allows mapping index to node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
@@ -599,7 +599,7 @@ class ActionManager:
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.session: "PrimaiteGame" = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
self.protocols: List[str] = protocols
|
||||
@@ -826,7 +826,7 @@ class ActionManager:
|
||||
return nics[nic_idx]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
|
||||
def from_config(cls, session: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
@@ -37,7 +37,7 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
|
||||
@@ -91,7 +91,7 @@ class FileObservation(AbstractObservation):
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
@@ -149,7 +149,7 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
|
||||
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
|
||||
@@ -219,7 +219,7 @@ class LinkObservation(AbstractObservation):
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
@@ -310,7 +310,7 @@ class FolderObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
|
||||
@@ -376,9 +376,7 @@ class NicObservation(AbstractObservation):
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
|
||||
) -> "NicObservation":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
@@ -515,7 +513,7 @@ class NodeObservation(AbstractObservation):
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
session: "PrimaiteSession",
|
||||
session: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
@@ -694,7 +692,7 @@ class AclObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
@@ -740,7 +738,7 @@ class NullObservation(AbstractObservation):
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
@@ -836,7 +834,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
@@ -907,7 +905,7 @@ class UC2RedObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
@@ -966,7 +964,7 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
|
||||
@@ -34,7 +34,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
@@ -47,7 +47,7 @@ class AbstractReward:
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
|
||||
def from_config(cls, config: dict, session: "PrimaiteGame") -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
@@ -68,7 +68,7 @@ class DummyReward(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
|
||||
def from_config(cls, config: dict, session: "PrimaiteGame") -> "DummyReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor. Should be empty.
|
||||
@@ -119,7 +119,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "DatabaseFileIntegrity":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
@@ -193,7 +193,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "WebServer404Penalty":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
@@ -265,7 +265,7 @@ class RewardFunction:
|
||||
return self.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteGame") -> "RewardFunction":
|
||||
"""Create a reward function from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward manager's constructor
|
||||
|
||||
@@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
@@ -17,10 +17,10 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
|
||||
def __init__(self, session: "PrimaiteGame", agents: List[ProxyAgent]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.session: "PrimaiteGame" = session
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -13,9 +9,6 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.environment import PrimaiteGymEnv
|
||||
from primaite.game.io import SessionIO, SessionIOSettings
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -37,7 +30,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
class PrimaiteGameOptions(BaseModel):
|
||||
"""
|
||||
Global options which are applicable to all of the agents in the game.
|
||||
|
||||
@@ -46,37 +39,17 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
max_episode_length: int = 256
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
class PrimaiteGame:
|
||||
"""
|
||||
Primaite game encapsulates the simulation and agents which interact with it.
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
|
||||
Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise a PrimaiteSession object."""
|
||||
@@ -95,15 +68,9 @@ class PrimaiteSession:
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.options: PrimaiteSessionOptions
|
||||
self.options: PrimaiteGameOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
|
||||
@@ -116,40 +83,6 @@ class PrimaiteSession:
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -210,7 +143,7 @@ class PrimaiteSession:
|
||||
def calculate_truncated(self) -> bool:
|
||||
"""Calculate whether the episode is truncated."""
|
||||
current_step = self.step_counter
|
||||
max_steps = self.training_options.max_steps_per_episode
|
||||
max_steps = self.options.max_episode_length
|
||||
if current_step >= max_steps:
|
||||
return True
|
||||
return False
|
||||
@@ -227,8 +160,8 @@ class PrimaiteSession:
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
|
||||
"""Create a PrimaiteGame object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent.
|
||||
@@ -243,23 +176,16 @@ class PrimaiteSession:
|
||||
:return: A PrimaiteSession object.
|
||||
:rtype: PrimaiteSession
|
||||
"""
|
||||
sess = cls()
|
||||
sess.options = PrimaiteSessionOptions(
|
||||
ports=cfg["game_config"]["ports"],
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(cfg["game"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
sim = sess.simulation
|
||||
# 1. create simulation
|
||||
sim = game.simulation
|
||||
net = sim.network
|
||||
|
||||
sess.ref_map_nodes: Dict[str, Node] = {}
|
||||
sess.ref_map_services: Dict[str, Service] = {}
|
||||
sess.ref_map_links: Dict[str, Link] = {}
|
||||
game.ref_map_nodes: Dict[str, Node] = {}
|
||||
game.ref_map_services: Dict[str, Service] = {}
|
||||
game.ref_map_links: Dict[str, Link] = {}
|
||||
|
||||
nodes_cfg = cfg["simulation"]["network"]["nodes"]
|
||||
links_cfg = cfg["simulation"]["network"]["links"]
|
||||
@@ -323,7 +249,7 @@ class PrimaiteSession:
|
||||
print(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_types_mapping[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
sess.ref_map_services[service_ref] = new_service
|
||||
game.ref_map_services[service_ref] = new_service
|
||||
else:
|
||||
print(f"service type not found {service_type}")
|
||||
# service-dependent options
|
||||
@@ -348,7 +274,7 @@ class PrimaiteSession:
|
||||
if application_type in application_types_mapping:
|
||||
new_node.software_manager.install(application_types_mapping[application_type])
|
||||
new_application = new_node.software_manager.software[application_type]
|
||||
sess.ref_map_applications[application_ref] = new_application
|
||||
game.ref_map_applications[application_ref] = new_application
|
||||
else:
|
||||
print(f"application type not found {application_type}")
|
||||
if "nics" in node_cfg:
|
||||
@@ -357,7 +283,7 @@ class PrimaiteSession:
|
||||
|
||||
net.add_node(new_node)
|
||||
new_node.power_on()
|
||||
sess.ref_map_nodes[
|
||||
game.ref_map_nodes[
|
||||
node_ref
|
||||
] = (
|
||||
new_node.uuid
|
||||
@@ -365,8 +291,8 @@ class PrimaiteSession:
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
if isinstance(node_a, Switch):
|
||||
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
|
||||
else:
|
||||
@@ -376,7 +302,7 @@ class PrimaiteSession:
|
||||
else:
|
||||
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
|
||||
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
|
||||
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
|
||||
# 3. create agents
|
||||
game_cfg = cfg["game_config"]
|
||||
@@ -390,14 +316,14 @@ class PrimaiteSession:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, game)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
# if a list of nodes is defined, convert them from node references to node UUIDs
|
||||
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
|
||||
if "node_ref" in action_node_option:
|
||||
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
|
||||
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
|
||||
action_space_cfg["options"]["node_uuids"].append(node_uuid)
|
||||
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
|
||||
# we will pass node_uuids as a part of the action space config.
|
||||
@@ -409,12 +335,12 @@ class PrimaiteSession:
|
||||
if "options" in action_config:
|
||||
if "target_router_ref" in action_config["options"]:
|
||||
_target = action_config["options"]["target_router_ref"]
|
||||
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
|
||||
action_config["options"]["target_router_uuid"] = game.ref_map_nodes[_target]
|
||||
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=game)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
@@ -425,7 +351,7 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
game.agents.append(new_agent)
|
||||
elif agent_type == "ProxyAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -433,8 +359,8 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agents.append(new_agent)
|
||||
game.agents.append(new_agent)
|
||||
game.rl_agents.append(new_agent)
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -442,16 +368,8 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
game.agents.append(new_agent)
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
return game
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
from primaite.game.game import PrimaiteGame, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
@@ -32,7 +32,7 @@ class PolicyABC(ABC):
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
def __init__(self, session: "PrimaiteGame") -> None:
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
@@ -41,7 +41,7 @@ class PolicyABC(ABC):
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.session: "PrimaiteGame" = session
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -69,7 +69,7 @@ class PolicyABC(ABC):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union
|
||||
from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.environment import PrimaiteGymEnv
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray.rllib.algorithms import Algorithm, ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
"""Single agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
super().__init__(session=session)
|
||||
ray.init()
|
||||
|
||||
@@ -71,21 +68,23 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
|
||||
for ep in range(n_episodes):
|
||||
res = self._algo.train()
|
||||
print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}")
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
raise NotImplementedError
|
||||
"""Save the policy to a file."""
|
||||
self._algo.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
@@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
from primaite.game.game import PrimaiteGame, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
@@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import load
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
@@ -32,7 +32,7 @@ def run(
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
|
||||
129
src/primaite/notebooks/train_rllib_single_agent.ipynb
Normal file
129
src/primaite/notebooks/train_rllib_single_agent.ipynb
Normal file
@@ -0,0 +1,129 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"2023-11-18 09:06:45,876\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
|
||||
"2023-11-18 09:06:48,446\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
|
||||
"2023-11-18 09:06:48,692\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.game.environment import PrimaiteGymEnv\n",
|
||||
"import yaml"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.config.load import example_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"installing DNSServer on node domain_controller\n",
|
||||
"installing DatabaseClient on node web_server\n",
|
||||
"installing WebServer on node web_server\n",
|
||||
"installing DatabaseService on node database_server\n",
|
||||
"service type not found DatabaseBackup\n",
|
||||
"installing DataManipulationBot on node client_1\n",
|
||||
"installing DNSClient on node client_1\n",
|
||||
"installing DNSClient on node client_2\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"sess = PrimaiteGame.from_config(cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = sess.env"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<primaite.game.environment.PrimaiteGymEnv at 0x7fad7190d7b0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
0
src/primaite/session/__init__.py
Normal file
0
src/primaite/session/__init__.py
Normal file
92
src/primaite/session/session.py
Normal file
92
src/primaite/session/session.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.game.environment import PrimaiteGymEnv
|
||||
|
||||
# from primaite.game.game import PrimaiteGame
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary."""
|
||||
sess = cls()
|
||||
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
@@ -74,7 +74,7 @@ def file_system() -> FileSystem:
|
||||
|
||||
|
||||
# PrimAITE v2 stuff
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
class TempPrimaiteSession(PrimaiteGame):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user