Add GATE wheel temporarily

This commit is contained in:
Marek Wolan
2023-10-10 09:48:04 +01:00
parent 91f06c15f6
commit b53c3856dd
4 changed files with 149 additions and 101 deletions

View File

@@ -1,8 +1,12 @@
training_config:
rl_framework: SB3
rl_algo: PPO
rl_algorithm: PPO
seed: 333
n_learn_episodes: 2
n_learn_steps: 128
n_learn_episodes: 1000
n_eval_episodes: 2
n_eval_steps: 128
game_config:
ports:

View File

@@ -1,5 +1,64 @@
from typing import Any, Dict, Optional, Tuple

import numpy as np
from arcd_gate.client.gate_client import ActType, GATEClient
from gymnasium import spaces

from primaite.game.agent.interface import AbstractGATEAgent
class GATEMan(GATEClient):
    """Placeholder GATE client whose members are hard-coded or not yet implemented.

    The integer properties return ``0``, ``seed`` returns ``None`` and
    ``rl_framework`` returns ``"SB3"``. Every other member is an empty stub
    that currently returns ``None`` regardless of its annotated return type.
    """

    @property
    def rl_framework(self) -> str:
        """Return the RL framework name, hard-coded to ``"SB3"``."""
        # Defined exactly once: a second definition further down the class body
        # would silently replace this one and make the property return None.
        return "SB3"

    @property
    def rl_algorithm(self) -> str:
        """Return the RL algorithm name. Not implemented yet (returns ``None``)."""
        pass

    @property
    def seed(self) -> Optional[int]:
        """Return the random seed; ``None`` means no seed is set."""
        return None

    @property
    def n_learn_episodes(self) -> int:
        """Return the number of learning episodes (placeholder ``0``)."""
        return 0

    @property
    def n_learn_steps(self) -> int:
        """Return the number of learning steps (placeholder ``0``)."""
        return 0

    @property
    def n_eval_episodes(self) -> int:
        """Return the number of evaluation episodes (placeholder ``0``)."""
        return 0

    @property
    def n_eval_steps(self) -> int:
        """Return the number of evaluation steps (placeholder ``0``)."""
        return 0

    @property
    def action_space(self) -> spaces.Space:
        """Return the action space. Not implemented yet (returns ``None``)."""
        pass

    @property
    def observation_space(self) -> spaces.Space:
        """Return the observation space. Not implemented yet (returns ``None``)."""
        pass

    def step(self, action: ActType) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Apply ``action`` for one step. Not implemented yet (returns ``None``)."""
        pass

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """Reset the environment. Not implemented yet (returns ``None``).

        :param seed: Optional seed; currently ignored.
        :param options: Optional reset options; currently ignored.
        """
        pass

    def close(self):
        """Release resources held by the client. Currently a no-op."""
        pass
class GATERLAgent(AbstractGATEAgent):
    """Placeholder agent type; it adds nothing to ``AbstractGATEAgent`` yet."""

    pass
# The communication with GATE needs to be handled by the PrimaiteSession rather than by individual agents,
# because once we support MARL, the actions from multiple agents will have to be batched.
# For example, MultiAgentEnv in Ray allows sending a dict of observations for multiple agents, and it then
# replies with the actions for those agents.

View File

@@ -6,7 +6,9 @@
# 5. idk
from ipaddress import IPv4Address
from typing import Dict, List
from typing import Any, Dict, List, Optional, Tuple
from gymnasium.vector.utils import spaces
import numpy as np
from pydantic import BaseModel
@@ -44,26 +46,101 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from arcd_gate.client.gate_client import GATEClient, ActType
from numpy import ndarray
# Module-level logger, keyed by this module's import name.
_LOGGER = getLogger(__name__)
class PrimaiteGATEClient(GATEClient):
    """GATE client that takes all of its settings from a parent ``PrimaiteSession``.

    The training parameters are read from ``parent_session.training_options``
    and the spaces from ``parent_session.rl_agent``.
    """

    def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
        """Create the client and keep a reference to the owning session.

        :param parent_session: Session whose ``training_options`` and ``rl_agent`` back the properties below.
        :param service_port: Port number forwarded to ``GATEClient.__init__``.
        """
        super().__init__(service_port=service_port)
        # The reference must actually be assigned: a bare ``self.x: T`` annotation
        # creates no attribute, and every member below reads ``self.parent_session``.
        self.parent_session: "PrimaiteSession" = parent_session

    @property
    def rl_framework(self) -> str:
        """Return ``training_options.rl_framework`` of the parent session."""
        return self.parent_session.training_options.rl_framework

    @property
    def rl_algorithm(self) -> str:
        """Return ``training_options.rl_algorithm`` of the parent session."""
        return self.parent_session.training_options.rl_algorithm

    @property
    def seed(self) -> int | None:
        """Return ``training_options.seed`` of the parent session (may be ``None``)."""
        return self.parent_session.training_options.seed

    @property
    def n_learn_episodes(self) -> int:
        """Return ``training_options.n_learn_episodes`` of the parent session."""
        return self.parent_session.training_options.n_learn_episodes

    @property
    def n_learn_steps(self) -> int:
        """Return ``training_options.n_learn_steps`` of the parent session."""
        return self.parent_session.training_options.n_learn_steps

    @property
    def n_eval_episodes(self) -> int:
        """Return ``training_options.n_eval_episodes`` of the parent session."""
        return self.parent_session.training_options.n_eval_episodes

    @property
    def n_eval_steps(self) -> int:
        """Return ``training_options.n_eval_steps`` of the parent session."""
        return self.parent_session.training_options.n_eval_steps

    @property
    def action_space(self) -> spaces.Space:
        """Return the action space of the parent session's ``rl_agent``."""
        return self.parent_session.rl_agent.action_space

    @property
    def observation_space(self) -> spaces.Space:
        """Return the observation space of the parent session's ``rl_agent``."""
        return self.parent_session.rl_agent.observation_space

    def step(self, action: ActType) -> Tuple[ndarray, float, bool, bool, Dict]:
        """Advance the parent session by one step.

        NOTE: unfinished. ``action`` is not passed on to the session and nothing
        is returned yet, so callers receive ``None`` instead of the annotated tuple.

        :param action: Action supplied by GATE; currently unused.
        """
        self.parent_session.step()
        # TODO: not sure how to go about this.

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ndarray, Dict]:
        """Reset the environment. Not implemented yet (returns ``None``).

        :param seed: Optional seed; currently ignored.
        :param options: Optional reset options; currently ignored.
        """
        ...

    def close(self):
        """Release resources held by the client. Not implemented yet."""
        ...
class PrimaiteSessionOptions(BaseModel):
    """Validated session-wide options: the port and protocol names in use."""

    # Port names, as strings.
    ports: List[str]
    # Protocol names, as strings.
    protocols: List[str]
class TrainingOptions(BaseModel):
    """Validated training settings, one field per training option."""

    # Which RL framework and algorithm to use.
    rl_framework: str
    rl_algorithm: str

    # Random seed; may be None.
    seed: Optional[int]

    # Learning budget.
    n_learn_episodes: int
    n_learn_steps: int

    # Evaluation budget.
    n_eval_episodes: int
    n_eval_steps: int
class PrimaiteSession:
def __init__(self):
    """Create an empty session: fresh simulation, no agents, counters at zero."""
    self.simulation: Simulation = Simulation()

    self.agents: List[AbstractAgent] = []
    # Annotation only -- no value is bound here.
    # which of the agents should be used for sending RL data to GATE client?
    self.rl_agent: AbstractAgent

    self.step_counter: int = 0
    self.episode_counter: int = 0

    # Annotations only -- no values are bound here.
    self.options: PrimaiteSessionOptions
    self.training_options: TrainingOptions

    # Lookup tables keyed by string reference, all empty to begin with.
    self.ref_map_nodes: Dict[str, Node] = {}
    self.ref_map_services: Dict[str, Service] = {}
    self.ref_map_links: Dict[str, Link] = {}
def start_session(self, opts="TODO..."):
    """Commence the session, handing control of the simulation/agent loop to the GATE client.

    Not implemented yet: the body is empty and ``opts`` is unused.
    """
    return None
def eval(self, opts="TODO..."):
    """Placeholder with no behaviour yet: the body is empty and ``opts`` is unused."""
    return None
def step(self):
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
@@ -108,6 +185,7 @@ class PrimaiteSession:
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg['training_config'])
sim = sess.simulation
net = sim.network
@@ -231,104 +309,6 @@ class PrimaiteSession:
# CREATE OBSERVATION SPACE
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
"""
# if observation_space_cfg is None:
# obs_space = NullObservation()
# elif observation_space_cfg["type"] == "UC2BlueObservation":
# node_obs_list = []
# link_obs_list = []
# # node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses
# node_ip_to_index = {}
# for node_idx, node_cfg in enumerate(nodes_cfg):
# n_ref = node_cfg["ref"]
# n_obj = net.nodes[ref_map_nodes[n_ref]]
# for nic_uuid, nic_obj in n_obj.nics.items():
# node_ip_to_index[nic_obj.ip_address] = node_idx + 2
# for node_obs_cfg in observation_space_cfg["options"]["nodes"]:
# node_ref = node_obs_cfg["node_ref"]
# folder_obs_list = []
# service_obs_list = []
# if "services" in node_obs_cfg:
# for service_obs_cfg in node_obs_cfg["services"]:
# service_obs_list.append(
# ServiceObservation(
# where=[
# "network",
# "nodes",
# ref_map_nodes[node_ref],
# "services",
# ref_map_services[service_obs_cfg["service_ref"]],
# ]
# )
# )
# if "folders" in node_obs_cfg:
# for folder_obs_cfg in node_obs_cfg["folders"]:
# file_obs_list = []
# if "files" in folder_obs_cfg:
# for file_obs_cfg in folder_obs_cfg["files"]:
# file_obs_list.append(
# FileObservation(
# where=[
# "network",
# "nodes",
# ref_map_nodes[node_ref],
# "folders",
# folder_obs_cfg["folder_name"],
# "files",
# file_obs_cfg["file_name"],
# ]
# )
# )
# folder_obs_list.append(
# FolderObservation(
# where=[
# "network",
# "nodes",
# ref_map_nodes[node_ref],
# "folders",
# folder_obs_cfg["folder_name"],
# ],
# files=file_obs_list,
# )
# )
# nic_obs_list = []
# for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg["node_ref"]]].nics.keys():
# nic_obs_list.append(
# NicObservation(where=["network", "nodes", ref_map_nodes[node_ref], "NICs", nic_uuid])
# )
# node_obs_list.append(
# NodeObservation(
# where=["network", "nodes", ref_map_nodes[node_ref]],
# services=service_obs_list,
# folders=folder_obs_list,
# nics=nic_obs_list,
# logon_status=False,
# )
# )
# for link_obs_cfg in observation_space_cfg["options"]["links"]:
# link_ref = link_obs_cfg["link_ref"]
# link_obs_list.append(LinkObservation(where=["network", "links", ref_map_links[link_ref]]))
# acl_obs = AclObservation(
# node_ip_to_id=node_ip_to_index,
# ports=game_cfg["ports"],
# protocols=game_cfg["ports"],
# where=["network", "nodes", observation_space_cfg["options"]["acl"]["router_node_ref"]],
# )
# obs_space = UC2BlueObservation(
# nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation()
# )
# elif observation_space_cfg["type"] == "UC2RedObservation":
# obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim)
# elif observation_space_cfg["type"] == "UC2GreenObservation":
# obs_space = UC2GreenObservation.from_config(observation_space_cfg.get('options',{}))
# else:
# print("observation space config not specified correctly.")
# obs_space = NullObservation()
"""
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs

View File

@@ -0,0 +1,5 @@
"""Utility script to start the gate server for running PrimAITE in attached mode."""
from arcd_gate.server.gate_service import GATEService
# Build a GATE service with its default arguments and start it; this runs at
# module level, so it happens as soon as the file is executed or imported.
service = GATEService()
service.start()