From b53c3856dd470e88fb18d8635758d9374398dfea Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 10 Oct 2023 09:48:04 +0100 Subject: [PATCH] Add GATE wheel temporarily --- example_config.yaml | 8 +- src/primaite/game/agent/GATE_agents.py | 59 ++++++++ src/primaite/game/session.py | 178 +++++++++++------------- src/primaite/utils/start_gate_server.py | 5 + 4 files changed, 149 insertions(+), 101 deletions(-) create mode 100644 src/primaite/utils/start_gate_server.py diff --git a/example_config.yaml b/example_config.yaml index f7faf589..afdf1b0a 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -1,8 +1,12 @@ training_config: rl_framework: SB3 - rl_algo: PPO + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 2 n_learn_steps: 128 - n_learn_episodes: 1000 + n_eval_episodes: 2 + n_eval_steps: 128 + game_config: ports: diff --git a/src/primaite/game/agent/GATE_agents.py b/src/primaite/game/agent/GATE_agents.py index 5bdfebe4..ac1d776b 100644 --- a/src/primaite/game/agent/GATE_agents.py +++ b/src/primaite/game/agent/GATE_agents.py @@ -1,5 +1,64 @@ from primaite.game.agent.interface import AbstractGATEAgent +from arcd_gate.client.gate_client import GATEClient +class GATEMan(GATEClient): + + @property + def rl_framework(self) -> str: + return "SB3" + + @property + def rl_framework(self) -> str: + pass + + @property + def rl_algorithm(self) -> str: + pass + + @property + def seed(self) -> Optional[int]: + return None + + @property + def n_learn_episodes(self) -> int: + return 0 + + @property + def n_learn_steps(self) -> int: + return 0 + + @property + def n_eval_episodes(self) -> int: + return 0 + + @property + def n_eval_steps(self) -> int: + return 0 + + @property + def action_space(self) -> spaces.Space: + pass + + @property + def observation_space(self) -> spaces.Space: + pass + + def step(self, action: ActType) -> Tuple[np.ndarray, float, bool, bool, Dict]: + pass + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None) 
-> Tuple[np.ndarray, Dict]: + pass + + def close(self): + pass + class GATERLAgent(AbstractGATEAgent): ... + # The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents, + # because when we are supporting MARL, the actions from multiple agents will have to be batched + + # For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply + # with the actions for those agents. + + diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 7b2225ef..7746f78c 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -6,7 +6,9 @@ # 5. idk from ipaddress import IPv4Address -from typing import Dict, List +from typing import Any, Dict, List, Optional, Tuple +from gymnasium.vector.utils import spaces +import numpy as np from pydantic import BaseModel @@ -44,26 +46,101 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot from primaite.simulator.system.services.service import Service +from arcd_gate.client.gate_client import GATEClient, ActType +from numpy import ndarray + _LOGGER = getLogger(__name__) +class PrimaiteGATEClient(GATEClient): + def __init__(self, parent_session:"PrimaiteSession", service_port: int = 50000): + super().__init__(service_port=service_port) + self.parent_session:"PrimaiteSession" + + @property + def rl_framework(self) -> str: + return self.parent_session.training_options.rl_framework + + @property + def rl_algorithm(self) -> str: + return self.parent_session.training_options.rl_algorithm + + @property + def seed(self) -> int | None: + return self.parent_session.training_options.seed + + @property + def n_learn_episodes(self) -> int: + return self.parent_session.training_options.n_learn_episodes + + @property + def n_learn_steps(self) -> int: + return
self.parent_session.training_options.n_learn_steps + + @property + def n_eval_episodes(self) -> int: + return self.parent_session.training_options.n_eval_episodes + + @property + def n_eval_steps(self) -> int: + return self.parent_session.training_options.n_eval_steps + + @property + def action_space(self) -> spaces.Space: + return self.parent_session.rl_agent.action_space + + @property + def observation_space(self) -> spaces.Space: + return self.parent_session.rl_agent.observation_space + + def step(self, action: ActType) -> Tuple[ndarray, float, bool, bool, Dict]: + self.parent_session.step() + #TODO: not sure how to go about this. + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ndarray, Dict]: + ... + + def close(self): + ... class PrimaiteSessionOptions(BaseModel): ports: List[str] protocols: List[str] +class TrainingOptions(BaseModel): + rl_framework:str + rl_algorithm:str + seed:Optional[int] + n_learn_episodes:int + n_learn_steps:int + n_eval_episodes:int + n_eval_steps:int + + class PrimaiteSession: def __init__(self): self.simulation: Simulation = Simulation() self.agents: List[AbstractAgent] = [] + self.rl_agent: AbstractAgent + # which of the agents should be used for sending RL data to GATE client? self.step_counter: int = 0 self.episode_counter: int = 0 self.options: PrimaiteSessionOptions + self.training_options: TrainingOptions self.ref_map_nodes: Dict[str, Node] = {} self.ref_map_services: Dict[str, Service] = {} self.ref_map_links: Dict[str, Link] = {} + def start_session(self, opts="TODO..."): + """Commence the session, this gives the gate client control over the simulation/agent loop.""" + ... + + def eval(self, opts="TODO..."): + ... + + + def step(self): _LOGGER.debug(f"Stepping primaite session. 
Step counter: {self.step_counter}") # currently designed with assumption that all agents act once per step in order @@ -108,6 +185,7 @@ class PrimaiteSession: ports=cfg["game_config"]["ports"], protocols=cfg["game_config"]["protocols"], ) + sess.training_options = TrainingOptions(**cfg['training_config']) sim = sess.simulation net = sim.network @@ -231,104 +309,6 @@ class PrimaiteSession: # CREATE OBSERVATION SPACE obs_space = ObservationSpace.from_config(observation_space_cfg, sess) - """ - # if observation_space_cfg is None: - # obs_space = NullObservation() - # elif observation_space_cfg["type"] == "UC2BlueObservation": - # node_obs_list = [] - # link_obs_list = [] - - # # node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses - # node_ip_to_index = {} - # for node_idx, node_cfg in enumerate(nodes_cfg): - # n_ref = node_cfg["ref"] - # n_obj = net.nodes[ref_map_nodes[n_ref]] - # for nic_uuid, nic_obj in n_obj.nics.items(): - # node_ip_to_index[nic_obj.ip_address] = node_idx + 2 - - # for node_obs_cfg in observation_space_cfg["options"]["nodes"]: - # node_ref = node_obs_cfg["node_ref"] - # folder_obs_list = [] - # service_obs_list = [] - # if "services" in node_obs_cfg: - # for service_obs_cfg in node_obs_cfg["services"]: - # service_obs_list.append( - # ServiceObservation( - # where=[ - # "network", - # "nodes", - # ref_map_nodes[node_ref], - # "services", - # ref_map_services[service_obs_cfg["service_ref"]], - # ] - # ) - # ) - # if "folders" in node_obs_cfg: - # for folder_obs_cfg in node_obs_cfg["folders"]: - # file_obs_list = [] - # if "files" in folder_obs_cfg: - # for file_obs_cfg in folder_obs_cfg["files"]: - # file_obs_list.append( - # FileObservation( - # where=[ - # "network", - # "nodes", - # ref_map_nodes[node_ref], - # "folders", - # folder_obs_cfg["folder_name"], - # "files", - # file_obs_cfg["file_name"], - # ] - # ) - # ) - # folder_obs_list.append( - # FolderObservation( - 
# where=[ - # "network", - # "nodes", - # ref_map_nodes[node_ref], - # "folders", - # folder_obs_cfg["folder_name"], - # ], - # files=file_obs_list, - # ) - # ) - # nic_obs_list = [] - # for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg["node_ref"]]].nics.keys(): - # nic_obs_list.append( - # NicObservation(where=["network", "nodes", ref_map_nodes[node_ref], "NICs", nic_uuid]) - # ) - # node_obs_list.append( - # NodeObservation( - # where=["network", "nodes", ref_map_nodes[node_ref]], - # services=service_obs_list, - # folders=folder_obs_list, - # nics=nic_obs_list, - # logon_status=False, - # ) - # ) - # for link_obs_cfg in observation_space_cfg["options"]["links"]: - # link_ref = link_obs_cfg["link_ref"] - # link_obs_list.append(LinkObservation(where=["network", "links", ref_map_links[link_ref]])) - - # acl_obs = AclObservation( - # node_ip_to_id=node_ip_to_index, - # ports=game_cfg["ports"], - # protocols=game_cfg["ports"], - # where=["network", "nodes", observation_space_cfg["options"]["acl"]["router_node_ref"]], - # ) - # obs_space = UC2BlueObservation( - # nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation() - # ) - # elif observation_space_cfg["type"] == "UC2RedObservation": - # obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim) - # elif observation_space_cfg["type"] == "UC2GreenObservation": - # obs_space = UC2GreenObservation.from_config(observation_space_cfg.get('options',{})) - # else: - # print("observation space config not specified correctly.") - # obs_space = NullObservation() - """ - # CREATE ACTION SPACE action_space_cfg["options"]["node_uuids"] = [] # if a list of nodes is defined, convert them from node references to node UUIDs diff --git a/src/primaite/utils/start_gate_server.py b/src/primaite/utils/start_gate_server.py new file mode 100644 index 00000000..53508cd2 --- /dev/null +++ b/src/primaite/utils/start_gate_server.py @@ -0,0 +1,5 @@ +"""Utility script to start the gate server for 
running PrimAITE in attached mode.""" +from arcd_gate.server.gate_service import GATEService + +service = GATEService() +service.start()