Update data manipulation bot

This commit is contained in:
Jake Walker
2023-11-24 10:33:19 +00:00
41 changed files with 3648 additions and 499 deletions

View File

@@ -29,6 +29,15 @@ class _PrimaitePaths:
def __init__(self) -> None:
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
self.user_home_path = self.generate_user_home_path()
self.user_sessions_path = self.generate_user_sessions_path()
self.user_config_path = self.generate_user_config_path()
self.user_notebooks_path = self.generate_user_notebooks_path()
self.app_home_path = self.generate_app_home_path()
self.app_config_dir_path = self.generate_app_config_dir_path()
self.app_config_file_path = self.generate_app_config_file_path()
self.app_log_dir_path = self.generate_app_log_dir_path()
self.app_log_file_path = self.generate_app_log_file_path()
def _get_dirs_properties(self) -> List[str]:
class_items = self.__class__.__dict__.items()
@@ -43,55 +52,47 @@ class _PrimaitePaths:
for p in self._get_dirs_properties():
getattr(self, p)
@property
def user_home_path(self) -> Path:
def generate_user_home_path(self) -> Path:
"""The PrimAITE user home path."""
path = Path.home() / "primaite" / __version__
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_sessions_path(self) -> Path:
def generate_user_sessions_path(self) -> Path:
"""The PrimAITE user sessions path."""
path = self.user_home_path / "sessions"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_config_path(self) -> Path:
def generate_user_config_path(self) -> Path:
"""The PrimAITE user config path."""
path = self.user_home_path / "config"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_notebooks_path(self) -> Path:
def generate_user_notebooks_path(self) -> Path:
"""The PrimAITE user notebooks path."""
path = self.user_home_path / "notebooks"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_home_path(self) -> Path:
def generate_app_home_path(self) -> Path:
"""The PrimAITE app home path."""
path = self._dirs.user_data_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_dir_path(self) -> Path:
def generate_app_config_dir_path(self) -> Path:
"""The PrimAITE app config directory path."""
path = self._dirs.user_config_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_file_path(self) -> Path:
def generate_app_config_file_path(self) -> Path:
"""The PrimAITE app config file path."""
return self.app_config_dir_path / "primaite_config.yaml"
@property
def app_log_dir_path(self) -> Path:
def generate_app_log_dir_path(self) -> Path:
"""The PrimAITE app log directory path."""
if sys.platform == "win32":
path = self.app_home_path / "logs"
@@ -100,8 +101,7 @@ class _PrimaitePaths:
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_log_file_path(self) -> Path:
def generate_app_log_file_path(self) -> Path:
"""The PrimAITE app log file path."""
return self.app_log_dir_path / "primaite.log"
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}

View File

@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
from arcd_gate.cli import setup as gate_setup
from primaite import getLogger
from primaite.setup import reset_demo_notebooks, reset_example_configs
@@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Setting up ARCD GATE...")
gate_setup()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
@@ -131,16 +127,10 @@ def session(
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from threading import Thread
from primaite.config.load import example_config_path
from primaite.main import run
from primaite.utils.start_gate_server import start_gate_server
server_thread = Thread(target=start_gate_server)
server_thread.start()
if not config:
config = example_config_path()
print(config)
run(config_path=config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -2,10 +2,17 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 20
n_learn_steps: 128
n_eval_episodes: 20
n_eval_steps: 128
n_learn_episodes: 25
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
@@ -107,7 +114,7 @@ game_config:
- ref: defender
team: BLUE
type: GATERLAgent
type: ProxyAgent
observation_space:
type: UC2BlueObservation

View File

@@ -1,31 +0,0 @@
# flake8: noqa
from typing import Dict, Optional, Tuple
from gymnasium.core import ActType, ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
class GATERLAgent(AbstractGATEAgent):
...
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
# because when we are supporting MARL, the actions form multiple agents will have to be batched
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
# with the actions for those agents.
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
) -> None:
super().__init__(agent_name, action_space, observation_space, reward_function)
self.most_recent_action: ActType
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
return self.most_recent_action

View File

@@ -4,10 +4,11 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
import numpy as np
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
if TYPE_CHECKING:
@@ -55,7 +56,7 @@ class AbstractAgent(ABC):
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings],
) -> None:
@@ -72,21 +73,21 @@ class AbstractAgent(ABC):
:type reward_function: Optional[RewardFunction]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
def convert_state_to_obs(self, state: Dict) -> ObsType:
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
return self.observation_manager.update(state)
def calculate_reward_from_state(self, state: Dict) -> float:
def update_reward(self, state: Dict) -> float:
"""
Use the reward function to calculate a reward from the state.
@@ -95,10 +96,10 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.calculate(state)
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -111,7 +112,7 @@ class AbstractAgent(ABC):
:return: Action to be taken in the environment.
:rtype: Tuple[str, Dict]
"""
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
@@ -119,7 +120,7 @@ class AbstractAgent(ABC):
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.action_space.form_request(action_identifier=action, action_options=options)
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
@@ -132,7 +133,7 @@ class AbstractScriptedAgent(AbstractAgent):
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
@@ -142,7 +143,47 @@ class RandomAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
return self.action_space.get_action(self.action_space.space.sample())
return self.action_manager.get_action(self.action_manager.space.sample())
class ProxyAgent(AbstractAgent):
"""Agent that sends observations to an RL model and receives actions from that model."""
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
) -> None:
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.most_recent_action: ActType
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""
Return the agent's most recent action, formatted in CAOS format.
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
:type obs: ObsType
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
:type reward: float, optional
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.most_recent_action)
def store_action(self, action: ActType):
"""
Store the most recent action taken by the agent.
The environment is responsible for calling this method when it receives an action from the agent policy.
"""
self.most_recent_action = action
class DataManipulationAgent(AbstractScriptedAgent):

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
pass
class ObservationSpace:
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -947,15 +948,17 @@ class ObservationSpace:
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.current_observation: ObsType
def observe(self, state: Dict) -> Dict:
def update(self, state: Dict) -> Dict:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
"""
return self.obs.observe(state)
self.current_observation = self.obs.observe(state)
return self.current_observation
@property
def space(self) -> None:
@@ -963,7 +966,7 @@ class ObservationSpace:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.

View File

@@ -26,7 +26,7 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
class RewardFunction:
"""Manages the reward function for the agent."""
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
@@ -238,6 +238,7 @@ class RewardFunction:
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.
@@ -249,7 +250,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def calculate(self, state: Dict) -> float:
def update(self, state: Dict) -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -260,7 +261,8 @@ class RewardFunction:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
return total
self.current_reward = total
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":

63
src/primaite/game/io.py Normal file
View File

@@ -0,0 +1,63 @@
from datetime import datetime
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, ConfigDict
from primaite import PRIMAITE_PATHS
from primaite.simulator import SIM_OUTPUT
class SessionIOSettings(BaseModel):
"""Schema for session IO settings."""
model_config = ConfigDict(extra="forbid")
save_final_model: bool = True
"""Whether to save the final model right at the end of training."""
save_checkpoints: bool = False
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
checkpoint_interval: int = 10
"""How often to save a checkpoint model (if save_checkpoints is True)."""
save_logs: bool = True
"""Whether to save logs"""
save_transactions: bool = True
"""Whether to save transactions, If true, the session path will have a transactions folder."""
save_tensorboard_logs: bool = False
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
class SessionIO:
"""
Class for managing session IO.
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
"""
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
self.settings: SessionIOSettings = settings
self.session_path: Path = self.generate_session_path()
# set global SIM_OUTPUT path
SIM_OUTPUT.path = self.session_path / "simulation_output"
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
# possible refactor needed
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
if timestamp is None:
timestamp = datetime.now()
date_str = timestamp.strftime("%Y-%m-%d")
time_str = timestamp.strftime("%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def generate_model_save_path(self, agent_name: str) -> Path:
"""Return the path where the final model will be saved (excluding filename extension)."""
return self.session_path / "checkpoints" / f"{agent_name}_final"
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
"""Return the path where the checkpoint model will be saved (excluding filename extension)."""
return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"

View File

@@ -0,0 +1,3 @@
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]

View File

@@ -0,0 +1,84 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, Type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
:param name: Identifier used by from_config to create an instance of the policy.
:type name: str
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate policy name {identifier}")
cls._registry[identifier] = cls
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
"""
Initialize a reinforcement learning policy.
:param session: The session context.
:type session: PrimaiteSession
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass
@abstractmethod
def save(self, save_path: Path) -> None:
"""Save the agent."""
pass
@abstractmethod
def load(self) -> None:
"""Load agent from a file."""
pass
def close(self) -> None:
"""Close the agent."""
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.
Subclasses should not call super().from_config(), they should just handle creation form config.
"""
# Assume that basically the contents of training_config are passed into here.
# I should really define a config schema class using pydantic.
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass

View File

@@ -0,0 +1,80 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
self._agent_class: Type[Union[PPO, A2C]]
if algorithm == "PPO":
self._agent_class = PPO
policy = PPO_MLP
elif algorithm == "A2C":
self._agent_class = A2C
policy = A2C_MLP
else:
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
self._agent = self._agent_class(
policy=policy,
env=self.session.env,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.io_manager.settings.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)
else:
checkpoint_callback = None
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
reward_data = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
print(reward_data)
def save(self, save_path: Path) -> None:
"""
Save the current policy parameters.
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
Therefore, this method is only used to save the final model.
"""
self._agent.save(save_path)
def load(self, model_path: Path) -> None:
"""Load agent from a checkpoint."""
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,18 +1,21 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from copy import deepcopy
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
from arcd_gate.client.gate_client import ActType, GATEClient
from gymnasium import spaces
import gymnasium
from gymnasium.core import ActType, ObsType
from gymnasium.spaces.utils import flatten, flatten_space
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, RandomAgent
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -34,109 +37,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteGATEClient(GATEClient):
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
"""
Create a new GATE client for PrimAITE.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
:param parent_session: The parent session object.
:type parent_session: PrimaiteSession
:param service_port: The port on which the GATE service is running.
:type service_port: int, optional
"""
super().__init__(service_port=service_port)
self.parent_session: "PrimaiteSession" = parent_session
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
@property
def rl_framework(self) -> str:
"""The reinforcement learning framework to use."""
return self.parent_session.training_options.rl_framework
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.session.apply_agent_actions()
self.session.advance_timestep()
state = self.session.get_sim_state()
self.session.update_agents(state)
@property
def rl_algorithm(self) -> str:
"""The reinforcement learning algorithm to use."""
return self.parent_session.training_options.rl_algorithm
@property
def seed(self) -> Optional[int]:
"""The seed to use for the environment's random number generator."""
return self.parent_session.training_options.seed
@property
def n_learn_episodes(self) -> int:
"""The number of episodes in each learning run."""
return self.parent_session.training_options.n_learn_episodes
@property
def n_learn_steps(self) -> int:
"""The number of steps in each learning episode."""
return self.parent_session.training_options.n_learn_steps
@property
def n_eval_episodes(self) -> int:
"""The number of episodes in each evaluation run."""
return self.parent_session.training_options.n_eval_episodes
@property
def n_eval_steps(self) -> int:
"""The number of steps in each evaluation episode."""
return self.parent_session.training_options.n_eval_steps
@property
def action_space(self) -> spaces.Space:
"""The gym action space of the agent."""
return self.parent_session.rl_agent.action_space.space
@property
def observation_space(self) -> spaces.Space:
"""The gymnasium observation space of the agent."""
return flatten_space(self.parent_session.rl_agent.observation_space.space)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
"""Take a step in the environment.
This method is called by GATE to advance the simulation by one timestep.
:param action: The agent's action.
:type action: ActType
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
:rtype: Tuple[ObsType, float, bool, bool, Dict]
"""
self.parent_session.rl_agent.most_recent_action = action
self.parent_session.step()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
rew = self.parent_session.rl_agent.reward_function.calculate(state)
term = False
trunc = False
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.session.calculate_truncated()
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict]:
"""Reset the environment.
return next_obs, reward, terminated, truncated, info
This method is called when the environment is initialized and at the end of each episode.
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
:param seed: The seed to use for the environment's random number generator.
:type seed: int, optional
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
compatibility with GATE.
:type options: dict[str, Any], optional
:return: The initial observation and an empty info dictionary.
:rtype: Tuple[ObsType, Dict]
"""
self.parent_session.reset()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
return obs, {}
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
def close(self):
"""Close the session, this will stop the gate client and close the simulation."""
self.parent_session.close()
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
@@ -146,6 +102,8 @@ class PrimaiteSessionOptions(BaseModel):
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
"""
model_config = ConfigDict(extra="forbid")
ports: List[str]
protocols: List[str]
@@ -153,48 +111,105 @@ class PrimaiteSessionOptions(BaseModel):
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
rl_framework: str
rl_algorithm: str
seed: Optional[int]
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int
n_eval_steps: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
def __init__(self):
"""Initialise a PrimaiteSession object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agent: AbstractAgent
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
self.gate_client.start()
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
def step(self):
"""
@@ -208,57 +223,76 @@ class PrimaiteSession:
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
interacts with should implement a step method that calls methods used by this method. For example, if using a
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
# 5. primaite session asks agents to produce an action based on most recent state
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
sim_state = self.simulation.describe_state()
agent.update_observation(state)
agent.update_reward(state)
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.convert_state_to_obs(sim_state)
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.calculate_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
_LOGGER.debug("Getting agent action")
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
self.simulation.apply_request(agent_request)
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
self.simulation.apply_timestep(self.step_counter)
def advance_timestep(self) -> None:
"""Advance timestep."""
self.step_counter += 1
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
self.simulation.apply_timestep(self.step_counter)
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
if current_step >= max_steps:
return True
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
return NotImplemented
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
def close(self) -> None:
"""Close the session, this will stop the gate client and close the simulation."""
"""Close the session, this will stop the env and close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent. Used by GATE.
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
@@ -276,6 +310,11 @@ class PrimaiteSession:
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
net = sim.network
@@ -425,7 +464,7 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
@@ -479,16 +518,15 @@ class PrimaiteSession:
agent_settings=agent_settings,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
new_agent = RandomAgent(
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
agent_settings=agent_settings,
)
sess.agents.append(new_agent)
sess.rl_agent = new_agent
sess.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -501,4 +539,14 @@ class PrimaiteSession:
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
return sess

View File

@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
@@ -31,7 +32,7 @@ def run(
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()

View File

@@ -1,14 +1,26 @@
"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory."""
from datetime import datetime
from pathlib import Path
from primaite import _PRIMAITE_ROOT
SIM_OUTPUT = None
"A path at the repo root dir to use temporarily for sim output testing while in dev."
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
__all__ = ["SIM_OUTPUT"]
if not SIM_OUTPUT:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)
class __SimOutput:
def __init__(self):
self._path: Path = (
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
@property
def path(self) -> Path:
return self._path
@path.setter
def path(self, new_path: Path) -> None:
self._path = new_path
self._path.mkdir(exist_ok=True, parents=True)
SIM_OUTPUT = __SimOutput()

View File

@@ -957,7 +957,7 @@ class Node(SimComponent):
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):

View File

@@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network:
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
# Client 2
client_2 = Computer(

View File

@@ -2,13 +2,14 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from uuid import uuid4
from prettytable import PrettyTable
from primaite import getLogger
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.software_manager import SoftwareManager
_LOGGER = getLogger(__name__)
class DatabaseClient(Application):
"""
@@ -130,7 +131,7 @@ class DatabaseClient(Application):
def execute(self) -> None:
"""Run the DatabaseClient."""
# super().execute()
super().execute()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
@@ -148,21 +149,6 @@ class DatabaseClient(Application):
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
def _print_data(self, data: Dict):
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
if data:
table = PrettyTable(list(data.values())[0])
table.align = "l"
table.title = f"{self.sys_log.hostname} Database Client"
for row in data.values():
table.add_row(row.values())
print(table)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
@@ -179,5 +165,5 @@ class DatabaseClient(Application):
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
self._print_data(payload["data"])
_LOGGER.debug(f"Received payload {payload}")
return True

View File

@@ -30,7 +30,7 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.execute()
self.run()
def describe_state(self) -> Dict:
"""
@@ -135,6 +135,3 @@ class WebBrowser(Application):
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
self.latest_response = payload
return True
def execute(self):
pass

View File

@@ -75,7 +75,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = SIM_OUTPUT / self.hostname
root = SIM_OUTPUT.path / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")

View File

@@ -41,7 +41,6 @@ class SysLog:
JSON-like messages.
"""
log_path = self._get_log_path()
file_handler = logging.FileHandler(filename=log_path)
file_handler.setLevel(logging.DEBUG)
@@ -81,7 +80,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT / self.hostname
root = SIM_OUTPUT.path / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -1,10 +1,6 @@
import sqlite3
from datetime import datetime
from ipaddress import IPv4Address
from sqlite3 import OperationalError
from typing import Any, Dict, List, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from typing import Any, Dict, List, Literal, Optional, Union
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -19,7 +15,7 @@ class DatabaseService(Service):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
password: Optional[str] = None
@@ -41,38 +37,6 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
self._connect()
def _connect(self):
self._conn = sqlite3.connect(self._db_file.sim_path)
self._cursor = self._conn.cursor()
def tables(self) -> List[str]:
"""
Get a list of table names present in the database.
:return: List of table names.
"""
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
results = self._process_sql(sql, None)
if isinstance(results["data"], dict):
return list(results["data"].keys())
return []
def show(self, markdown: bool = False):
"""
Prints a list of table names in the database using PrettyTable.
:param markdown: Whether to output the table in Markdown format.
"""
table = PrettyTable(["Table"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.file_system.sys_log.hostname} Database"
for row in self.tables():
table.add_row([row])
print(table)
def configure_backup(self, backup_server: IPv4Address):
"""
@@ -89,8 +53,6 @@ class DatabaseService(Service):
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
self._conn.close()
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
@@ -98,12 +60,10 @@ class DatabaseService(Service):
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server,
src_file_name=self._db_file.name,
src_folder_name=self._db_file.folder.name,
src_folder_name=self.folder.name,
dest_folder_name=str(self.uuid),
dest_file_name="database.db",
real_file_path=self._db_file.sim_path,
)
self._connect()
if response:
return True
@@ -125,25 +85,29 @@ class DatabaseService(Service):
dest_ip_address=self.backup_server,
)
if response:
self._conn.close()
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self._connect()
if not response:
self.sys_log.error("Unable to restore database backup.")
return False
return self._db_file is not None
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self.sys_log.error("Unable to restore database backup.")
return False
if self._db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
self.set_health_state(SoftwareHealthState.GOOD)
return True
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
self.folder = self._db_file.folder
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
def _process_connect(
self, session_id: str, password: Optional[str] = None
@@ -163,31 +127,32 @@ class DatabaseService(Service):
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
Possible queries:
- SELECT : returns the data
- DELETE : deletes the data
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
try:
self._cursor.execute(query)
self._conn.commit()
except OperationalError:
# Handle the case where the table does not exist.
self.sys_log.error(f"{self.name}: Error, query failed")
return {"status_code": 404, "data": {}}
data = []
description = self._cursor.description
if description:
headers = []
for header in description:
headers.append(header[0])
data = self._cursor.fetchall()
if data and headers:
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
if query == "SELECT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
if self.health_state_actual == SoftwareHealthState.GOOD:
self.health_state_actual = SoftwareHealthState.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
else:
# Invalid query
return {"status_code": 500, "data": False}
def describe_state(self) -> Dict:
"""

View File

@@ -106,7 +106,7 @@ class WebServer(Service):
# get data from DatabaseServer
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
# get all users
if db_client.query("SELECT * FROM user;"):
if db_client.query("SELECT"):
# query succeeded
response.status_code = HttpStatusCode.OK

View File

@@ -1,12 +0,0 @@
"""Utility script to start the gate server for running PrimAITE in attached mode."""
from arcd_gate.server.gate_service import GATEService
def start_gate_server():
"""Start the gate server."""
service = GATEService()
service.start()
if __name__ == "__main__":
start_gate_server()