Update data manipulation bot
This commit is contained in:
@@ -29,6 +29,15 @@ class _PrimaitePaths:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
|
||||
self.user_home_path = self.generate_user_home_path()
|
||||
self.user_sessions_path = self.generate_user_sessions_path()
|
||||
self.user_config_path = self.generate_user_config_path()
|
||||
self.user_notebooks_path = self.generate_user_notebooks_path()
|
||||
self.app_home_path = self.generate_app_home_path()
|
||||
self.app_config_dir_path = self.generate_app_config_dir_path()
|
||||
self.app_config_file_path = self.generate_app_config_file_path()
|
||||
self.app_log_dir_path = self.generate_app_log_dir_path()
|
||||
self.app_log_file_path = self.generate_app_log_file_path()
|
||||
|
||||
def _get_dirs_properties(self) -> List[str]:
|
||||
class_items = self.__class__.__dict__.items()
|
||||
@@ -43,55 +52,47 @@ class _PrimaitePaths:
|
||||
for p in self._get_dirs_properties():
|
||||
getattr(self, p)
|
||||
|
||||
@property
|
||||
def user_home_path(self) -> Path:
|
||||
def generate_user_home_path(self) -> Path:
|
||||
"""The PrimAITE user home path."""
|
||||
path = Path.home() / "primaite" / __version__
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_sessions_path(self) -> Path:
|
||||
def generate_user_sessions_path(self) -> Path:
|
||||
"""The PrimAITE user sessions path."""
|
||||
path = self.user_home_path / "sessions"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_config_path(self) -> Path:
|
||||
def generate_user_config_path(self) -> Path:
|
||||
"""The PrimAITE user config path."""
|
||||
path = self.user_home_path / "config"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_notebooks_path(self) -> Path:
|
||||
def generate_user_notebooks_path(self) -> Path:
|
||||
"""The PrimAITE user notebooks path."""
|
||||
path = self.user_home_path / "notebooks"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_home_path(self) -> Path:
|
||||
def generate_app_home_path(self) -> Path:
|
||||
"""The PrimAITE app home path."""
|
||||
path = self._dirs.user_data_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_dir_path(self) -> Path:
|
||||
def generate_app_config_dir_path(self) -> Path:
|
||||
"""The PrimAITE app config directory path."""
|
||||
path = self._dirs.user_config_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_file_path(self) -> Path:
|
||||
def generate_app_config_file_path(self) -> Path:
|
||||
"""The PrimAITE app config file path."""
|
||||
return self.app_config_dir_path / "primaite_config.yaml"
|
||||
|
||||
@property
|
||||
def app_log_dir_path(self) -> Path:
|
||||
def generate_app_log_dir_path(self) -> Path:
|
||||
"""The PrimAITE app log directory path."""
|
||||
if sys.platform == "win32":
|
||||
path = self.app_home_path / "logs"
|
||||
@@ -100,8 +101,7 @@ class _PrimaitePaths:
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_log_file_path(self) -> Path:
|
||||
def generate_app_log_file_path(self) -> Path:
|
||||
"""The PrimAITE app log file path."""
|
||||
return self.app_log_dir_path / "primaite.log"
|
||||
|
||||
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"WARNING": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
|
||||
@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from arcd_gate.cli import setup as gate_setup
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import reset_demo_notebooks, reset_example_configs
|
||||
|
||||
@@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Setting up ARCD GATE...")
|
||||
gate_setup()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
config: Optional[str] = None,
|
||||
agent_load_file: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
@@ -131,16 +127,10 @@ def session(
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from threading import Thread
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.main import run
|
||||
from primaite.utils.start_gate_server import start_gate_server
|
||||
|
||||
server_thread = Thread(target=start_gate_server)
|
||||
server_thread.start()
|
||||
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
print(config)
|
||||
run(config_path=config)
|
||||
run(config_path=config, agent_load_path=agent_load_file)
|
||||
|
||||
@@ -2,10 +2,17 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 20
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 20
|
||||
n_eval_steps: 128
|
||||
n_learn_episodes: 25
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
|
||||
|
||||
game_config:
|
||||
@@ -107,7 +114,7 @@ game_config:
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: GATERLAgent
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
# flake8: noqa
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GATERLAgent(AbstractGATEAgent):
|
||||
...
|
||||
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
|
||||
# because when we are supporting MARL, the actions form multiple agents will have to be batched
|
||||
|
||||
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
|
||||
# with the actions for those agents.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
return self.most_recent_action
|
||||
@@ -4,10 +4,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -55,7 +56,7 @@ class AbstractAgent(ABC):
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings],
|
||||
) -> None:
|
||||
@@ -72,21 +73,21 @@ class AbstractAgent(ABC):
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
return self.observation_manager.update(state)
|
||||
|
||||
def calculate_reward_from_state(self, state: Dict) -> float:
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
Use the reward function to calculate a reward from the state.
|
||||
|
||||
@@ -95,10 +96,10 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.calculate(state)
|
||||
return self.reward_function.update(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
@@ -111,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
:return: Action to be taken in the environment.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
|
||||
@@ -119,7 +120,7 @@ class AbstractAgent(ABC):
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.action_space.form_request(action_identifier=action, action_options=options)
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
|
||||
@@ -132,7 +133,7 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
|
||||
:param obs: _description_
|
||||
@@ -142,7 +143,47 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
|
||||
:type reward: float, optional
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.most_recent_action)
|
||||
|
||||
def store_action(self, action: ActType):
|
||||
"""
|
||||
Store the most recent action taken by the agent.
|
||||
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
|
||||
pass
|
||||
|
||||
|
||||
class ObservationSpace:
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
@@ -947,15 +948,17 @@ class ObservationSpace:
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.current_observation: ObsType
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
"""
|
||||
return self.obs.observe(state)
|
||||
self.current_observation = self.obs.observe(state)
|
||||
return self.current_observation
|
||||
|
||||
@property
|
||||
def space(self) -> None:
|
||||
@@ -963,7 +966,7 @@ class ObservationSpace:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
|
||||
@@ -26,7 +26,7 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
|
||||
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
@@ -238,6 +238,7 @@ class RewardFunction:
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
self.current_reward: float
|
||||
|
||||
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
@@ -249,7 +250,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def update(self, state: Dict) -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -260,7 +261,8 @@ class RewardFunction:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
return total
|
||||
self.current_reward = total
|
||||
return self.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
|
||||
|
||||
63
src/primaite/game/io.py
Normal file
63
src/primaite/game/io.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class SessionIOSettings(BaseModel):
|
||||
"""Schema for session IO settings."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
save_final_model: bool = True
|
||||
"""Whether to save the final model right at the end of training."""
|
||||
save_checkpoints: bool = False
|
||||
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
|
||||
checkpoint_interval: int = 10
|
||||
"""How often to save a checkpoint model (if save_checkpoints is True)."""
|
||||
save_logs: bool = True
|
||||
"""Whether to save logs"""
|
||||
save_transactions: bool = True
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
"""
|
||||
Class for managing session IO.
|
||||
|
||||
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
|
||||
# set global SIM_OUTPUT path
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
time_str = timestamp.strftime("%H-%M-%S")
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
return session_path
|
||||
|
||||
def generate_model_save_path(self, agent_name: str) -> Path:
|
||||
"""Return the path where the final model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_final"
|
||||
|
||||
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
|
||||
"""Return the path where the checkpoint model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"
|
||||
3
src/primaite/game/policy/__init__.py
Normal file
3
src/primaite/game/policy/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from primaite.game.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy"]
|
||||
84
src/primaite/game/policy/policy.py
Normal file
84
src/primaite/game/policy/policy.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
"""Base class for reinforcement learning agents."""
|
||||
|
||||
_registry: Dict[str, Type["PolicyABC"]] = {}
|
||||
"""
|
||||
Registry of policy types, keyed by name.
|
||||
|
||||
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a policy subclass.
|
||||
|
||||
:param name: Identifier used by from_config to create an instance of the policy.
|
||||
:type name: str
|
||||
:raises ValueError: When attempting to create a policy with a duplicate name.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
:param session: The session context.
|
||||
:type session: PrimaiteSession
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> None:
|
||||
"""Load agent from a file."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
Subclasses should not call super().from_config(), they should just handle creation form config.
|
||||
"""
|
||||
# Assume that basically the contents of training_config are passed into here.
|
||||
# I should really define a config schema class using pydantic.
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
|
||||
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
|
||||
80
src/primaite/game/policy/sb3.py
Normal file
80
src/primaite/game/policy/sb3.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self._agent_class: Type[Union[PPO, A2C]]
|
||||
if algorithm == "PPO":
|
||||
self._agent_class = PPO
|
||||
policy = PPO_MLP
|
||||
elif algorithm == "A2C":
|
||||
self._agent_class = A2C
|
||||
policy = A2C_MLP
|
||||
else:
|
||||
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
|
||||
self._agent = self._agent_class(
|
||||
policy=policy,
|
||||
env=self.session.env,
|
||||
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
if self.session.io_manager.settings.save_checkpoints:
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
|
||||
save_path=self.session.io_manager.generate_model_save_path("sb3"),
|
||||
name_prefix="sb3_model",
|
||||
)
|
||||
else:
|
||||
checkpoint_callback = None
|
||||
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
reward_data = evaluate_policy(
|
||||
self._agent,
|
||||
self.session.env,
|
||||
n_eval_episodes=n_episodes,
|
||||
deterministic=deterministic,
|
||||
return_episode_rewards=True,
|
||||
)
|
||||
print(reward_data)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""
|
||||
Save the current policy parameters.
|
||||
|
||||
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
|
||||
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
|
||||
Therefore, this method is only used to save the final model.
|
||||
"""
|
||||
self._agent.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load agent from a checkpoint."""
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -1,18 +1,21 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
|
||||
from arcd_gate.client.gate_client import ActType, GATEClient
|
||||
from gymnasium import spaces
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces.utils import flatten, flatten_space
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.io import SessionIO, SessionIOSettings
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -34,109 +37,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
|
||||
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
|
||||
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
|
||||
"""
|
||||
Create a new GATE client for PrimAITE.
|
||||
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
:param parent_session: The parent session object.
|
||||
:type parent_session: PrimaiteSession
|
||||
:param service_port: The port on which the GATE service is running.
|
||||
:type service_port: int, optional
|
||||
"""
|
||||
super().__init__(service_port=service_port)
|
||||
self.parent_session: "PrimaiteSession" = parent_session
|
||||
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
"""The reinforcement learning framework to use."""
|
||||
return self.parent_session.training_options.rl_framework
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.session.apply_agent_actions()
|
||||
self.session.advance_timestep()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
|
||||
@property
|
||||
def rl_algorithm(self) -> str:
|
||||
"""The reinforcement learning algorithm to use."""
|
||||
return self.parent_session.training_options.rl_algorithm
|
||||
|
||||
@property
|
||||
def seed(self) -> Optional[int]:
|
||||
"""The seed to use for the environment's random number generator."""
|
||||
return self.parent_session.training_options.seed
|
||||
|
||||
@property
|
||||
def n_learn_episodes(self) -> int:
|
||||
"""The number of episodes in each learning run."""
|
||||
return self.parent_session.training_options.n_learn_episodes
|
||||
|
||||
@property
|
||||
def n_learn_steps(self) -> int:
|
||||
"""The number of steps in each learning episode."""
|
||||
return self.parent_session.training_options.n_learn_steps
|
||||
|
||||
@property
|
||||
def n_eval_episodes(self) -> int:
|
||||
"""The number of episodes in each evaluation run."""
|
||||
return self.parent_session.training_options.n_eval_episodes
|
||||
|
||||
@property
|
||||
def n_eval_steps(self) -> int:
|
||||
"""The number of steps in each evaluation episode."""
|
||||
return self.parent_session.training_options.n_eval_steps
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
"""The gym action space of the agent."""
|
||||
return self.parent_session.rl_agent.action_space.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
"""The gymnasium observation space of the agent."""
|
||||
return flatten_space(self.parent_session.rl_agent.observation_space.space)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
|
||||
"""Take a step in the environment.
|
||||
|
||||
This method is called by GATE to advance the simulation by one timestep.
|
||||
|
||||
:param action: The agent's action.
|
||||
:type action: ActType
|
||||
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
|
||||
:rtype: Tuple[ObsType, float, bool, bool, Dict]
|
||||
"""
|
||||
self.parent_session.rl_agent.most_recent_action = action
|
||||
self.parent_session.step()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
rew = self.parent_session.rl_agent.reward_function.calculate(state)
|
||||
term = False
|
||||
trunc = False
|
||||
next_obs = self._get_obs()
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.session.calculate_truncated()
|
||||
info = {}
|
||||
return obs, rew, term, trunc, info
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment.
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
This method is called when the environment is initialized and at the end of each episode.
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.session.reset()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
:param seed: The seed to use for the environment's random number generator.
|
||||
:type seed: int, optional
|
||||
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
|
||||
compatibility with GATE.
|
||||
:type options: dict[str, Any], optional
|
||||
:return: The initial observation and an empty info dictionary.
|
||||
:rtype: Tuple[ObsType, Dict]
|
||||
"""
|
||||
self.parent_session.reset()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
return obs, {}
|
||||
@property
|
||||
def action_space(self) -> gymnasium.Space:
|
||||
"""Return the action space of the environment."""
|
||||
return self.agent.action_manager.space
|
||||
|
||||
def close(self):
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
self.parent_session.close()
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
@@ -146,6 +102,8 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
@@ -153,48 +111,105 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
rl_framework: str
|
||||
rl_algorithm: str
|
||||
seed: Optional[int]
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_learn_steps: int
|
||||
n_eval_episodes: int
|
||||
n_eval_steps: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise a PrimaiteSession object."""
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self._simulation_initial_state = deepcopy(self.simulation)
|
||||
"""The Simulation original state (deepcopy of the original Simulation)."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
self.rl_agent: AbstractAgent
|
||||
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
|
||||
|
||||
self.rl_agents: List[ProxyAgent] = []
|
||||
"""Subset of agent list including only the reinforcement learning agents."""
|
||||
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.options: PrimaiteSessionOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
|
||||
self.ref_map_services: Dict[str, Service] = {}
|
||||
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
|
||||
|
||||
self.ref_map_applications: Dict[str, Application] = {}
|
||||
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
|
||||
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
|
||||
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
|
||||
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
self.gate_client.start()
|
||||
"""Commence the training session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
@@ -208,57 +223,76 @@ class PrimaiteSession:
|
||||
4. Each agent chooses an action based on the observation.
|
||||
5. Each agent converts the action to a request.
|
||||
6. The simulation applies the requests.
|
||||
|
||||
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
|
||||
interacts with should implement a step method that calls methods used by this method. For example, if using a
|
||||
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
|
||||
``self.apply_agent_actions()``.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
return self.simulation.describe_state()
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for agent in self.agents:
|
||||
# 3. primaite session asks simulation to provide initial state
|
||||
# 4. primate session gives state to all agents
|
||||
# 5. primaite session asks agents to produce an action based on most recent state
|
||||
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
|
||||
sim_state = self.simulation.describe_state()
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
|
||||
# 6. each agent takes most recent state and converts it to CAOS observation
|
||||
agent_obs = agent.convert_state_to_obs(sim_state)
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
for agent in self.agents:
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
|
||||
# 7. meanwhile each agent also takes state and calculates reward
|
||||
agent_reward = agent.calculate_reward_from_state(sim_state)
|
||||
|
||||
# 8. each agent takes observation and applies decision rule to observation to create CAOS
|
||||
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
|
||||
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
|
||||
# code should live inside of the GATE agent subclass)
|
||||
# gets action in CAOS format
|
||||
_LOGGER.debug("Getting agent action")
|
||||
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
|
||||
agent_request = agent.format_request(agent_action, action_options)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
|
||||
self.simulation.apply_request(agent_request)
|
||||
|
||||
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
self.step_counter += 1
|
||||
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
|
||||
def calculate_truncated(self) -> bool:
|
||||
"""Calculate whether the episode is truncated."""
|
||||
current_step = self.step_counter
|
||||
max_steps = self.training_options.max_steps_per_episode
|
||||
if current_step >= max_steps:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
return NotImplemented
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
|
||||
self.simulation = deepcopy(self._simulation_initial_state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent. Used by GATE.
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
@@ -276,6 +310,11 @@ class PrimaiteSession:
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
@@ -425,7 +464,7 @@ class PrimaiteSession:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
@@ -479,16 +518,15 @@ class PrimaiteSession:
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
new_agent = RandomAgent(
|
||||
elif agent_type == "ProxyAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agent = new_agent
|
||||
sess.rl_agents.append(new_agent)
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -501,4 +539,14 @@ class PrimaiteSession:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
|
||||
|
||||
return sess
|
||||
|
||||
@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
def run(
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
agent_load_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
@@ -31,7 +32,7 @@ def run(
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory."""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import _PRIMAITE_ROOT
|
||||
|
||||
SIM_OUTPUT = None
|
||||
"A path at the repo root dir to use temporarily for sim output testing while in dev."
|
||||
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
if not SIM_OUTPUT:
|
||||
session_timestamp = datetime.now()
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
|
||||
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
class __SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
)
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return self._path
|
||||
|
||||
@path.setter
|
||||
def path(self, new_path: Path) -> None:
|
||||
self._path = new_path
|
||||
self._path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
SIM_OUTPUT = __SimOutput()
|
||||
|
||||
@@ -957,7 +957,7 @@ class Node(SimComponent):
|
||||
if not kwargs.get("session_manager"):
|
||||
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
|
||||
if not kwargs.get("root"):
|
||||
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
|
||||
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
|
||||
if not kwargs.get("file_system"):
|
||||
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
|
||||
if not kwargs.get("software_manager"):
|
||||
|
||||
@@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network:
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
|
||||
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
|
||||
# Client 2
|
||||
client_2 = Computer(
|
||||
|
||||
@@ -2,13 +2,14 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
"""
|
||||
@@ -130,7 +131,7 @@ class DatabaseClient(Application):
|
||||
|
||||
def execute(self) -> None:
|
||||
"""Run the DatabaseClient."""
|
||||
# super().execute()
|
||||
super().execute()
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.connect()
|
||||
|
||||
@@ -148,21 +149,6 @@ class DatabaseClient(Application):
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
|
||||
def _print_data(self, data: Dict):
|
||||
"""
|
||||
Display the contents of the Folder in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
if data:
|
||||
table = PrettyTable(list(data.values())[0])
|
||||
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} Database Client"
|
||||
for row in data.values():
|
||||
table.add_row(row.values())
|
||||
print(table)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
@@ -179,5 +165,5 @@ class DatabaseClient(Application):
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
self._print_data(payload["data"])
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
return True
|
||||
|
||||
@@ -30,7 +30,7 @@ class WebBrowser(Application):
|
||||
kwargs["port"] = Port.HTTP
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.execute()
|
||||
self.run()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -135,6 +135,3 @@ class WebBrowser(Application):
|
||||
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
|
||||
self.latest_response = payload
|
||||
return True
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
@@ -75,7 +75,7 @@ class PacketCapture:
|
||||
|
||||
def _get_log_path(self) -> Path:
|
||||
"""Get the path for the log file."""
|
||||
root = SIM_OUTPUT / self.hostname
|
||||
root = SIM_OUTPUT.path / self.hostname
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self._logger_name}.log"
|
||||
|
||||
|
||||
@@ -100,8 +100,16 @@ class SoftwareManager:
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
for key, value in self.port_protocol_mapping.items():
|
||||
if value.name == software_name:
|
||||
self.port_protocol_mapping.pop(key)
|
||||
break
|
||||
for key, value in self._software_class_to_name_map.items():
|
||||
if value == software_name:
|
||||
self._software_class_to_name_map.pop(key)
|
||||
break
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
self.sys_log.info(f"Uninstalled {software_name}")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ class SysLog:
|
||||
JSON-like messages.
|
||||
"""
|
||||
log_path = self._get_log_path()
|
||||
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
|
||||
@@ -81,7 +80,7 @@ class SysLog:
|
||||
|
||||
:return: Path object representing the location of the log file.
|
||||
"""
|
||||
root = SIM_OUTPUT / self.hostname
|
||||
root = SIM_OUTPUT.path / self.hostname
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self.hostname}_sys.log"
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from sqlite3 import OperationalError
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -19,7 +15,7 @@ class DatabaseService(Service):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
@@ -41,38 +37,6 @@ class DatabaseService(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._db_file: File
|
||||
self._create_db_file()
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
self._conn = sqlite3.connect(self._db_file.sim_path)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def tables(self) -> List[str]:
|
||||
"""
|
||||
Get a list of table names present in the database.
|
||||
|
||||
:return: List of table names.
|
||||
"""
|
||||
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
|
||||
results = self._process_sql(sql, None)
|
||||
if isinstance(results["data"], dict):
|
||||
return list(results["data"].keys())
|
||||
return []
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Prints a list of table names in the database using PrettyTable.
|
||||
|
||||
:param markdown: Whether to output the table in Markdown format.
|
||||
"""
|
||||
table = PrettyTable(["Table"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.file_system.sys_log.hostname} Database"
|
||||
for row in self.tables():
|
||||
table.add_row([row])
|
||||
print(table)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
"""
|
||||
@@ -89,8 +53,6 @@ class DatabaseService(Service):
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
self._conn.close()
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
|
||||
@@ -98,12 +60,10 @@ class DatabaseService(Service):
|
||||
response = ftp_client_service.send_file(
|
||||
dest_ip_address=self.backup_server,
|
||||
src_file_name=self._db_file.name,
|
||||
src_folder_name=self._db_file.folder.name,
|
||||
src_folder_name=self.folder.name,
|
||||
dest_folder_name=str(self.uuid),
|
||||
dest_file_name="database.db",
|
||||
real_file_path=self._db_file.sim_path,
|
||||
)
|
||||
self._connect()
|
||||
|
||||
if response:
|
||||
return True
|
||||
@@ -125,25 +85,29 @@ class DatabaseService(Service):
|
||||
dest_ip_address=self.backup_server,
|
||||
)
|
||||
|
||||
if response:
|
||||
self._conn.close()
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
self._connect()
|
||||
if not response:
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
|
||||
return self._db_file is not None
|
||||
# replace db file
|
||||
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
|
||||
self.file_system.copy_file(
|
||||
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
|
||||
)
|
||||
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
|
||||
|
||||
self.sys_log.error("Unable to restore database backup.")
|
||||
return False
|
||||
if self._db_file is None:
|
||||
self.sys_log.error("Copying database backup failed.")
|
||||
return False
|
||||
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
|
||||
return True
|
||||
|
||||
def _create_db_file(self):
|
||||
"""Creates the Simulation File and sqlite file in the file system."""
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
|
||||
self.folder = self._db_file.folder
|
||||
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
@@ -163,31 +127,32 @@ class DatabaseService(Service):
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
|
||||
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
Possible queries:
|
||||
- SELECT : returns the data
|
||||
- DELETE : deletes the data
|
||||
|
||||
:param query: The SQL query to be executed.
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
try:
|
||||
self._cursor.execute(query)
|
||||
self._conn.commit()
|
||||
except OperationalError:
|
||||
# Handle the case where the table does not exist.
|
||||
self.sys_log.error(f"{self.name}: Error, query failed")
|
||||
return {"status_code": 404, "data": {}}
|
||||
data = []
|
||||
description = self._cursor.description
|
||||
if description:
|
||||
headers = []
|
||||
for header in description:
|
||||
headers.append(header[0])
|
||||
data = self._cursor.fetchall()
|
||||
if data and headers:
|
||||
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
|
||||
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
else:
|
||||
# Invalid query
|
||||
return {"status_code": 500, "data": False}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -106,7 +106,7 @@ class WebServer(Service):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
|
||||
# get all users
|
||||
if db_client.query("SELECT * FROM user;"):
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
response.status_code = HttpStatusCode.OK
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
|
||||
def start_gate_server():
|
||||
"""Start the gate server."""
|
||||
service = GATEService()
|
||||
service.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_gate_server()
|
||||
Reference in New Issue
Block a user