diff --git a/.gitignore b/.gitignore index ef842c6e..892751d9 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ simulation_output/ benchmark/output # src/primaite/notebooks/scratch.ipynb src/primaite/notebooks/scratch.py +sandbox.py diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index b68861e1..7d5b50d6 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -655,8 +655,8 @@ simulation: - ref: data_manipulation_bot type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 services: diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 9450c419..b811bfa5 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -1,14 +1,10 @@ training_config: - rl_framework: RLLIB_single_agent - rl_algorithm: PPO - seed: 333 - n_learn_episodes: 1 - n_eval_episodes: 5 - max_steps_per_episode: 256 - deterministic_eval: false - n_agents: 1 + rl_framework: RLLIB_multi_agent + # rl_framework: SB3 + n_agents: 2 agent_references: - - defender + - defender_1 + - defender_2 io_settings: save_checkpoints: true @@ -36,31 +32,26 @@ agents: action_space: action_list: - type: DONOTHING - # - # - type: NODE_LOGON - # - type: NODE_LOGOFF - # - type: NODE_APPLICATION_EXECUTE - # options: - # execution_definition: - # target_address: arcd.com - + - type: NODE_APPLICATION_EXECUTE options: nodes: - node_ref: client_2 + applications: + - application_ref: client_2_web_browser max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_nics_per_node: 2 - max_acl_rules: 10 + 
max_applications_per_node: 1 reward_function: reward_components: - type: DUMMY agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + start_settings: + start_step: 5 + frequency: 4 + variance: 3 - ref: client_1_data_manipulation_red_bot team: RED @@ -69,38 +60,20 @@ agents: observation_space: type: UC2RedObservation options: - nodes: - - node_ref: client_1 - observations: - - logon_status - - operating_status - services: - - service_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} + nodes: {} action_space: action_list: - type: DONOTHING - # None: """Add a reward component to the reward function. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 38e9d5fc..a36cbea9 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -125,6 +125,7 @@ class PrimaiteGame: for agent in self.agents: agent.update_observation(state) agent.update_reward(state) + agent.reward_function.total_reward += agent.reward_function.current_reward def apply_agent_actions(self) -> None: """Apply all actions to simulation as requests.""" @@ -155,6 +156,8 @@ class PrimaiteGame: self.step_counter = 0 _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") self.simulation.reset_component_for_episode(episode=self.episode_counter) + for agent in self.agents: + agent.reward_function.total_reward = 0.0 def close(self) -> None: """Close the game, this will close the simulation.""" @@ -240,7 +243,7 @@ class PrimaiteGame: position=r_num, ) else: - print("invalid node type") + _LOGGER.warning(f"invalid node type {n_type} in config") if "services" in node_cfg: for service_cfg in node_cfg["services"]: new_service = None @@ -256,12 +259,12 @@ class PrimaiteGame: "FTPServer": FTPServer, } if service_type in service_types_mapping: - print(f"installing {service_type} on node {new_node.hostname}") + _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") 
new_node.software_manager.install(service_types_mapping[service_type]) new_service = new_node.software_manager.software[service_type] game.ref_map_services[service_ref] = new_service.uuid else: - print(f"service type not found {service_type}") + _LOGGER.warning(f"service type not found {service_type}") # service-dependent options if service_type == "DatabaseClient": if "options" in service_cfg: @@ -295,7 +298,7 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: - print(f"application type not found {application_type}") + _LOGGER.warning(f"application type not found {application_type}") if application_type == "DataManipulationBot": if "options" in application_cfg: @@ -416,7 +419,7 @@ class PrimaiteGame: ) game.agents.append(new_agent) else: - print("agent type not found") + _LOGGER.warning(f"agent type {agent_type} not found") game.simulation.set_original_state() diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index d31d53cc..0d4b6d0e 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -1,5 +1,21 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train a Multi agent system using RLLIB\n", + "\n", + "This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### First, Import packages and read our config file." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -8,75 +24,56 @@ "source": [ "from primaite.game.game import PrimaiteGame\n", "import yaml\n", - "from primaite.config.load import example_config_path\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(example_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n", + "from primaite.session.environment import PrimaiteRayEnv\n", + "from primaite import PRIMAITE_PATHS\n", "\n", - "game = PrimaiteGame.from_config(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# gym = PrimaiteRayEnv({\"game\":game})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "import ray\n", "from ray import air, tune\n", - "from ray.rllib.algorithms.ppo import PPOConfig" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ray.shutdown()\n", - "ray.init()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "from ray.rllib.algorithms.ppo import PPOConfig\n", "from primaite.session.environment import PrimaiteRayMARLEnv\n", "\n", + "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", + "# to copy the files to your user data path.\n", + "with open(PRIMAITE_PATHS.user_config_path / 'example_config/example_config_2_rl_agents.yaml', 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", "\n", - "env_config = {\"game\":game}\n", + "ray.init(local_mode=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create a Ray algorithm config which accepts our two agents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ "config = (\n", " PPOConfig()\n", - " .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n", - " .rollouts(num_rollout_workers=0)\n", " .multi_agent(\n", - " policies={agent.agent_name for agent in game.rl_agents},\n", + " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", + " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n", + " .rollouts(num_rollout_workers=0)\n", " .training(train_batch_size=128)\n", " )\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Set training parameters and start the training\n", + "This example will save outputs to a default Ray directory and use mostly default settings." + ] + }, { "cell_type": "code", "execution_count": null, @@ -86,21 +83,11 @@ "tune.Tuner(\n", " \"PPO\",\n", " run_config=air.RunConfig(\n", - " stop={\"training_iteration\": 128},\n", - " checkpoint_config=air.CheckpointConfig(\n", - " checkpoint_frequency=10,\n", - " ),\n", + " stop={\"timesteps_total\": 512},\n", " ),\n", " param_space=config\n", ").fit()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index 8ee16d41..a89b29e4 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -1,5 +1,13 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train a Single agent system using RLLib\n", + "This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -10,19 +18,25 @@ "import yaml\n", "from primaite.config.load import example_config_path\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "from primaite.session.environment import PrimaiteRayEnv\n", + "from ray.rllib.algorithms import ppo\n", + "from ray import air, tune\n", + "import ray\n", + "from ray.rllib.algorithms.ppo import PPOConfig\n", + "\n", + "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", + "# to copy the files to your user data path.\n", "with open(example_config_path(), 'r') as f:\n", " cfg = yaml.safe_load(f)\n", "\n", - "game = PrimaiteGame.from_config(cfg)" + "ray.init(local_mode=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create a Ray algorithm and pass it our config." 
] }, { @@ -31,7 +45,21 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteRayEnv({\"game\":game})" + "env_config = {\"cfg\":cfg}\n", + "\n", + "config = (\n", + " PPOConfig()\n", + " .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n", + " .rollouts(num_rollout_workers=0)\n", + " .training(train_batch_size=128)\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Set training parameters and start the training" ] }, { @@ -40,61 +68,13 @@ "metadata": {}, "outputs": [], "source": [ - "import ray\n", - "from ray.rllib.algorithms import ppo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ray.shutdown()\n", - "ray.init()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "env_config = {\"game\":game}\n", - "config = {\n", - " \"env\" : PrimaiteRayEnv,\n", - " \"env_config\" : env_config,\n", - " \"disable_env_checking\": True,\n", - " \"num_rollout_workers\": 0,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "algo = ppo.PPO(config=config)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(5):\n", - " result = algo.train()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "algo.save(\"temp/deleteme\")" + "tune.Tuner(\n", + " \"PPO\",\n", + " run_config=air.RunConfig(\n", + " stop={\"timesteps_total\": 512}\n", + " ),\n", + " param_space=config\n", + ").fit()\n" ] } ], diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a5fdade9..c2f19f36 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -37,11 +37,14 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = 
False truncated = self.game.calculate_truncated() info = {} - print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}") return next_obs, reward, terminated, truncated, info def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" + print( + f"Resetting environment, episode {self.game.episode_counter}, " + f"total reward: {self.game.rl_agents[0].reward_function.total_reward}" + ) self.game.reset() state = self.game.get_sim_state() self.game.update_agents(state) @@ -69,14 +72,15 @@ class PrimaiteGymEnv(gymnasium.Env): class PrimaiteRayEnv(gymnasium.Env): """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" - def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None: + def __init__(self, env_config: Dict) -> None: """Initialise the environment. :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. :type env_config: Dict[str, PrimaiteGame] """ - self.env = PrimaiteGymEnv(game=env_config["game"]) + self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) + self.env.game.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -92,14 +96,14 @@ class PrimaiteRayEnv(gymnasium.Env): class PrimaiteRayMARLEnv(MultiAgentEnv): """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" - def __init__(self, env_config: Optional[Dict] = None) -> None: + def __init__(self, env_config: Dict) -> None: """Initialise the environment. :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame] """ - self.game: PrimaiteGame = env_config["game"] + self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"]) """Reference to the primaite game""" self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} """List of all possible agents in the environment. This list should not change!""" @@ -108,7 +112,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.terminateds = set() self.truncateds = set() self.observation_space = gymnasium.spaces.Dict( - {name: agent.observation_manager.space for name, agent in self.agents.items()} + { + name: gymnasium.spaces.flatten_space(agent.observation_manager.space) + for name, agent in self.agents.items() + } ) self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} @@ -159,4 +166,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" - return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()} + obs = {} + for name, agent in self.agents.items(): + unflat_space = agent.observation_manager.space + unflat_obs = agent.observation_manager.current_observation + obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + return obs diff --git a/src/primaite/session/policy/rllib.py b/src/primaite/session/policy/rllib.py index be181797..ca69a2a8 100644 --- a/src/primaite/session/policy/rllib.py +++ b/src/primaite/session/policy/rllib.py @@ -12,6 +12,10 @@ from ray import air, tune from ray.rllib.algorithms import ppo from ray.rllib.algorithms.ppo import PPOConfig +from primaite import getLogger + +_LOGGER = getLogger(__name__) + class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): """Single agent RL policy using Ray RLLib.""" @@ -19,7 +23,7 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def __init__(self, session: 
"PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) - config = { + self.config = { "env": PrimaiteRayEnv, "env_config": {"game": session.game}, "disable_env_checking": True, @@ -29,12 +33,13 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): ray.shutdown() ray.init() - self._algo = ppo.PPO(config=config) - def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - for ep in range(n_episodes): - self._algo.train() + self.config["training_iterations"] = n_episodes * timesteps_per_episode + self.config["train_batch_size"] = 128 + self._algo = ppo.PPO(config=self.config) + _LOGGER.info("Starting RLLIB training session") + self._algo.train() def eval(self, n_episodes: int, deterministic: bool) -> None: """Evaluate the agent.""" diff --git a/src/primaite/session/policy/sb3.py b/src/primaite/session/policy/sb3.py index 051e2770..254baf4d 100644 --- a/src/primaite/session/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -51,14 +51,13 @@ class SB3Policy(PolicyABC, identifier="SB3"): def eval(self, n_episodes: int, deterministic: bool) -> None: """Evaluate the agent.""" - reward_data = evaluate_policy( + _ = evaluate_policy( self._agent, self.session.env, n_eval_episodes=n_episodes, deterministic=deterministic, return_episode_rewards=True, ) - print(reward_data) def save(self, save_path: Path) -> None: """ diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 3919902a..ef462d83 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -62,6 +62,7 @@ class PrimaiteSession: def start_session(self) -> None: """Commence the training/eval session.""" + print("Starting Primaite Session") self.mode = SessionMode.TRAIN n_learn_episodes = self.training_options.n_learn_episodes n_eval_episodes = self.training_options.n_eval_episodes diff --git a/src/primaite/simulator/core.py 
b/src/primaite/simulator/core.py index 5e1953e2..98a7e8db 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -113,7 +113,7 @@ class RequestManager(BaseModel): """ if name in self.request_types: msg = f"Overwriting request type {name}." - _LOGGER.warning(msg) + _LOGGER.debug(msg) self.request_types[name] = request_type diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 97b62f95..e1780448 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -220,7 +220,7 @@ class Network(SimComponent): self._node_id_map[len(self.nodes)] = node node.parent = self self._nx_graph.add_node(node.hostname) - _LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}") + _LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}") self._node_request_manager.add_request(name=node.uuid, request_type=RequestType(func=node._request_manager)) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 04c76c6b..a310a3f5 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -181,13 +181,13 @@ class NIC(SimComponent): if self.enabled: return if not self._connected_node: - _LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node") + _LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Node") return if self._connected_node.operating_state != NodeOperatingState.ON: self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on") return if not self._connected_link: - _LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link") + _LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Link") return self.enabled = True diff --git 
a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 6a7c80ca..61cf1560 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -56,7 +56,7 @@ class DatabaseService(Service): def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - print("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") + _LOGGER.debug(f"Resetting DatabaseService original state on node {self.software_manager.node.hostname}") self.connections.clear() super().reset_component_for_episode(episode) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index e63b875a..afd6cb74 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -47,7 +47,6 @@ state["last_response_status_code"] = ( self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None ) - print(state) return state def __init__(self, **kwargs):