Add GATE wheel temporarily
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algo: PPO
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 2
|
||||
n_learn_steps: 128
|
||||
n_learn_episodes: 1000
|
||||
n_eval_episodes: 2
|
||||
n_eval_steps: 128
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
|
||||
@@ -1,5 +1,64 @@
|
||||
from primaite.game.agent.interface import AbstractGATEAgent
|
||||
from arcd_gate.client.gate_client import GATEClient
|
||||
|
||||
|
||||
class GATEMan(GATEClient):
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
return "SB3"
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def rl_algorithm(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
def seed(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def n_learn_episodes(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def n_learn_steps(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def n_eval_episodes(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def n_eval_steps(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
def step(self, action: ActType) -> Tuple[np.ndarray, float, bool, bool, Dict]:
|
||||
pass
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None) -> Tuple[np.ndarray, Dict]:
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
class GATERLAgent(AbstractGATEAgent):
|
||||
...
|
||||
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
|
||||
# because when we are supporting MARL, the actions form multiple agents will have to be batched
|
||||
|
||||
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
|
||||
# with the actions for those agents.
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
# 5. idk
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from gymnasium.vector.utils import spaces
|
||||
import numpy as np
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -44,26 +46,101 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
from arcd_gate.client.gate_client import GATEClient, ActType
|
||||
from numpy import ndarray
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
|
||||
def __init__(self, parent_session:"PrimaiteSession", service_port: int = 50000):
|
||||
super().__init__(service_port=service_port)
|
||||
self.parent_session:"PrimaiteSession"
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
return self.parent_session.training_options.rl_framework
|
||||
|
||||
@property
|
||||
def rl_algorithm(self) -> str:
|
||||
return self.parent_session.training_options.rl_algorithm
|
||||
|
||||
@property
|
||||
def seed(self) -> int | None:
|
||||
return self.parent_session.training_options.seed
|
||||
|
||||
@property
|
||||
def n_learn_episodes(self) -> int:
|
||||
return self.parent_session.training_options.n_learn_episodes
|
||||
|
||||
@property
|
||||
def n_learn_steps(self) -> int:
|
||||
return self.parent_session.training_options.n_learn_steps
|
||||
|
||||
@property
|
||||
def n_eval_episodes(self) -> int:
|
||||
return self.parent_session.training_options.n_eval_episodes
|
||||
|
||||
@property
|
||||
def n_eval_steps(self) -> int:
|
||||
return self.parent_session.training_options.n_eval_steps
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
return self.parent_session.rl_agent.action_space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
return self.parent_session.rl_agent.observation_space
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ndarray, float, bool, bool, Dict]:
|
||||
self.parent_session.step()
|
||||
#TODO: not sure how to go about this.
|
||||
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ndarray, Dict]:
|
||||
...
|
||||
|
||||
def close(self):
|
||||
...
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
rl_framework:str
|
||||
rl_algorithm:str
|
||||
seed:Optional[int]
|
||||
n_learn_episodes:int
|
||||
n_learn_steps:int
|
||||
n_eval_episodes:int
|
||||
n_eval_steps:int
|
||||
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
def __init__(self):
|
||||
self.simulation: Simulation = Simulation()
|
||||
self.agents: List[AbstractAgent] = []
|
||||
self.rl_agent: AbstractAgent
|
||||
# which of the agents should be used for sending RL data to GATE client?
|
||||
self.step_counter: int = 0
|
||||
self.episode_counter: int = 0
|
||||
self.options: PrimaiteSessionOptions
|
||||
self.training_options: TrainingOptions
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
self.ref_map_services: Dict[str, Service] = {}
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
|
||||
def start_session(self, opts="TODO..."):
|
||||
"""Commence the session, this gives the gate client control over the simulation/agent loop."""
|
||||
...
|
||||
|
||||
def eval(self, opts="TODO..."):
|
||||
...
|
||||
|
||||
|
||||
|
||||
def step(self):
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
@@ -108,6 +185,7 @@ class PrimaiteSession:
|
||||
ports=cfg["game_config"]["ports"],
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg['training_config'])
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
@@ -231,104 +309,6 @@ class PrimaiteSession:
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
|
||||
"""
|
||||
# if observation_space_cfg is None:
|
||||
# obs_space = NullObservation()
|
||||
# elif observation_space_cfg["type"] == "UC2BlueObservation":
|
||||
# node_obs_list = []
|
||||
# link_obs_list = []
|
||||
|
||||
# # node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses
|
||||
# node_ip_to_index = {}
|
||||
# for node_idx, node_cfg in enumerate(nodes_cfg):
|
||||
# n_ref = node_cfg["ref"]
|
||||
# n_obj = net.nodes[ref_map_nodes[n_ref]]
|
||||
# for nic_uuid, nic_obj in n_obj.nics.items():
|
||||
# node_ip_to_index[nic_obj.ip_address] = node_idx + 2
|
||||
|
||||
# for node_obs_cfg in observation_space_cfg["options"]["nodes"]:
|
||||
# node_ref = node_obs_cfg["node_ref"]
|
||||
# folder_obs_list = []
|
||||
# service_obs_list = []
|
||||
# if "services" in node_obs_cfg:
|
||||
# for service_obs_cfg in node_obs_cfg["services"]:
|
||||
# service_obs_list.append(
|
||||
# ServiceObservation(
|
||||
# where=[
|
||||
# "network",
|
||||
# "nodes",
|
||||
# ref_map_nodes[node_ref],
|
||||
# "services",
|
||||
# ref_map_services[service_obs_cfg["service_ref"]],
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
# if "folders" in node_obs_cfg:
|
||||
# for folder_obs_cfg in node_obs_cfg["folders"]:
|
||||
# file_obs_list = []
|
||||
# if "files" in folder_obs_cfg:
|
||||
# for file_obs_cfg in folder_obs_cfg["files"]:
|
||||
# file_obs_list.append(
|
||||
# FileObservation(
|
||||
# where=[
|
||||
# "network",
|
||||
# "nodes",
|
||||
# ref_map_nodes[node_ref],
|
||||
# "folders",
|
||||
# folder_obs_cfg["folder_name"],
|
||||
# "files",
|
||||
# file_obs_cfg["file_name"],
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
# folder_obs_list.append(
|
||||
# FolderObservation(
|
||||
# where=[
|
||||
# "network",
|
||||
# "nodes",
|
||||
# ref_map_nodes[node_ref],
|
||||
# "folders",
|
||||
# folder_obs_cfg["folder_name"],
|
||||
# ],
|
||||
# files=file_obs_list,
|
||||
# )
|
||||
# )
|
||||
# nic_obs_list = []
|
||||
# for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg["node_ref"]]].nics.keys():
|
||||
# nic_obs_list.append(
|
||||
# NicObservation(where=["network", "nodes", ref_map_nodes[node_ref], "NICs", nic_uuid])
|
||||
# )
|
||||
# node_obs_list.append(
|
||||
# NodeObservation(
|
||||
# where=["network", "nodes", ref_map_nodes[node_ref]],
|
||||
# services=service_obs_list,
|
||||
# folders=folder_obs_list,
|
||||
# nics=nic_obs_list,
|
||||
# logon_status=False,
|
||||
# )
|
||||
# )
|
||||
# for link_obs_cfg in observation_space_cfg["options"]["links"]:
|
||||
# link_ref = link_obs_cfg["link_ref"]
|
||||
# link_obs_list.append(LinkObservation(where=["network", "links", ref_map_links[link_ref]]))
|
||||
|
||||
# acl_obs = AclObservation(
|
||||
# node_ip_to_id=node_ip_to_index,
|
||||
# ports=game_cfg["ports"],
|
||||
# protocols=game_cfg["ports"],
|
||||
# where=["network", "nodes", observation_space_cfg["options"]["acl"]["router_node_ref"]],
|
||||
# )
|
||||
# obs_space = UC2BlueObservation(
|
||||
# nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation()
|
||||
# )
|
||||
# elif observation_space_cfg["type"] == "UC2RedObservation":
|
||||
# obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim)
|
||||
# elif observation_space_cfg["type"] == "UC2GreenObservation":
|
||||
# obs_space = UC2GreenObservation.from_config(observation_space_cfg.get('options',{}))
|
||||
# else:
|
||||
# print("observation space config not specified correctly.")
|
||||
# obs_space = NullObservation()
|
||||
"""
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
# if a list of nodes is defined, convert them from node references to node UUIDs
|
||||
|
||||
5
src/primaite/utils/start_gate_server.py
Normal file
5
src/primaite/utils/start_gate_server.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
service = GATEService()
|
||||
service.start()
|
||||
Reference in New Issue
Block a user