Remove GATE-related code.
This commit is contained in:
@@ -29,17 +29,6 @@ jobs:
|
||||
pip install -e .[dev]
|
||||
displayName: 'Install PrimAITE for docs autosummary'
|
||||
|
||||
- script: |
|
||||
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
|
||||
python -m pip install $GATE_WHEEL[dev]
|
||||
displayName: 'Install GATE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
displayName: 'Install GATE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
- script: |
|
||||
primaite setup
|
||||
displayName: 'Perform PrimAITE Setup'
|
||||
|
||||
@@ -81,17 +81,6 @@ stages:
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
- script: |
|
||||
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
|
||||
python -m pip install $GATE_WHEEL[dev]
|
||||
displayName: 'Install GATE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
displayName: 'Install GATE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
- script: |
|
||||
primaite setup
|
||||
displayName: 'Perform PrimAITE Setup'
|
||||
|
||||
@@ -46,7 +46,6 @@ python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
.\.venv\Scripts\activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -75,7 +74,6 @@ cd ~/primaite
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install arcd_gate-0.1.0-py3-none-any.whl
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -120,7 +118,6 @@ source venv/bin/activate
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
pip install arcd_gate-0.1.0-py3-none-any.whl
|
||||
```
|
||||
|
||||
#### 6. Perform the PrimAITE setup:
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
# flake8: noqa
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GATERLAgent(AbstractGATEAgent):
    """Agent that replays an externally supplied action instead of computing one itself.

    The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
    because when we are supporting MARL, the actions from multiple agents will have to be batched.

    For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
    with the actions for those agents.
    """

    def __init__(
        self,
        agent_name: str | None,
        action_space: ActionManager | None,
        observation_space: ObservationSpace | None,
        reward_function: RewardFunction | None,
    ) -> None:
        """
        Initialise the agent by forwarding every argument to the base class.

        :param agent_name: Name of the agent, passed through unchanged.
        :param action_space: Action manager for the agent, passed through unchanged.
        :param observation_space: Observation space for the agent, passed through unchanged.
        :param reward_function: Reward function for the agent, passed through unchanged.
        """
        super().__init__(agent_name, action_space, observation_space, reward_function)
        # Annotation only: no value is assigned here. Something outside this class has to set
        # ``most_recent_action`` before ``get_action`` is called; otherwise the lookup raises
        # AttributeError (unless a base class happens to define the attribute - not visible here).
        self.most_recent_action: ActType

    def get_action(self, obs: ObsType, reward: float | None = None) -> Tuple[str, Dict]:
        """
        Return the action most recently stored on this agent.

        :param obs: Current observation. Not used by this implementation.
        :param reward: Reward for the previous step. Not used by this implementation.
        :return: Whatever is currently held in ``most_recent_action``.
        """
        # NOTE(review): the attribute is annotated as ActType while the return annotation is
        # Tuple[str, Dict] - confirm which of the two the callers actually rely on.
        return self.most_recent_action
|
||||
@@ -76,7 +76,7 @@ class AbstractAgent(ABC):
|
||||
:return: Action to be taken in the environment.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive an int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
|
||||
@@ -108,9 +108,3 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
    """
    Base class for actors controlled via external messages, such as RL policies.

    This class adds no attributes or methods of its own; it only provides a distinct type on top of
    ``AbstractAgent``.
    """
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from arcd_gate.client.gate_client import ActType, GATEClient
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces.utils import flatten, flatten_space
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -34,111 +30,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
    """
    Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE.

    The client stores nothing but a reference to its parent session. Every property reads its value from
    the parent session (its training options or its RL agent) at access time.
    """

    def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
        """
        Create a new GATE client for PrimAITE.

        :param parent_session: The parent session object.
        :type parent_session: PrimaiteSession
        :param service_port: The port on which the GATE service is running.
        :type service_port: int, optional
        """
        super().__init__(service_port=service_port)
        self.parent_session: "PrimaiteSession" = parent_session

    @property
    def rl_framework(self) -> str:
        """The reinforcement learning framework to use."""
        return self.parent_session.training_options.rl_framework

    @property
    def rl_algorithm(self) -> str:
        """The reinforcement learning algorithm to use."""
        return self.parent_session.training_options.rl_algorithm

    @property
    def seed(self) -> int | None:
        """The seed to use for the environment's random number generator."""
        return self.parent_session.training_options.seed

    @property
    def n_learn_episodes(self) -> int:
        """The number of episodes in each learning run."""
        return self.parent_session.training_options.n_learn_episodes

    @property
    def n_learn_steps(self) -> int:
        """The number of steps in each learning episode."""
        return self.parent_session.training_options.n_learn_steps

    @property
    def n_eval_episodes(self) -> int:
        """The number of episodes in each evaluation run."""
        return self.parent_session.training_options.n_eval_episodes

    @property
    def n_eval_steps(self) -> int:
        """The number of steps in each evaluation episode."""
        return self.parent_session.training_options.n_eval_steps

    @property
    def action_space(self) -> spaces.Space:
        """The gym action space of the agent."""
        return self.parent_session.rl_agent.action_space.space

    @property
    def observation_space(self) -> spaces.Space:
        """The gymnasium observation space of the agent, flattened with ``flatten_space``."""
        return flatten_space(self.parent_session.rl_agent.observation_space.space)

    def _flatten_observation(self, state: Any) -> ObsType:
        """
        Build the RL agent's observation for ``state`` and flatten it.

        Shared by ``step`` and ``reset`` so both produce observations in exactly the same way.

        :param state: Simulation state as returned by ``simulation.describe_state()``.
        :return: The observation flattened against the agent's observation space.
        """
        agent_obs_space = self.parent_session.rl_agent.observation_space
        # Observe first, then read ``.space`` - the same order as the original inline code.
        raw_obs = agent_obs_space.observe(state)
        return flatten(agent_obs_space.space, raw_obs)

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
        """Take a step in the environment.

        This method is called by GATE to advance the simulation by one timestep. The action is stored on the
        RL agent as ``most_recent_action`` before the parent session is stepped.

        :param action: The agent's action.
        :type action: ActType
        :return: The observation, reward, terminal flag, truncated flag, and info dictionary. The two flags
            are always ``False`` and the info dictionary is always empty in this implementation.
        :rtype: Tuple[ObsType, float, bool, bool, Dict]
        """
        session = self.parent_session
        session.rl_agent.most_recent_action = action
        session.step()
        state = session.simulation.describe_state()
        obs = self._flatten_observation(state)
        # Reward is calculated from the same state snapshot that produced the observation.
        rew = session.rl_agent.reward_function.calculate(state)
        term = False
        trunc = False
        info = {}
        return obs, rew, term, trunc, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
        """Reset the environment.

        This method is called when the environment is initialized and at the end of each episode.

        :param seed: The seed to use for the environment's random number generator. NOTE(review): the value
            is accepted but not forwarded to the parent session's ``reset()`` - confirm this is intended.
        :type seed: int, optional
        :param options: Additional options for the reset. None are used by PrimAITE but this is included for
            compatibility with GATE.
        :type options: dict[str, Any], optional
        :return: The initial observation and an empty info dictionary.
        :rtype: Tuple[ObsType, Dict]
        """
        self.parent_session.reset()
        state = self.parent_session.simulation.describe_state()
        return self._flatten_observation(state), {}

    def close(self) -> None:
        """Close the session by delegating to the parent session's ``close()``."""
        self.parent_session.close()
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
"""
|
||||
Global options which are applicable to all of the agents in the game.
|
||||
@@ -189,12 +80,10 @@ class PrimaiteSession:
|
||||
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
|
||||
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
self.gate_client.start()
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
|
||||
def start_gate_server():
    """Create a ``GATEService`` instance and call its ``start()`` method."""
    GATEService().start()


if __name__ == "__main__":
    start_gate_server()
|
||||
Reference in New Issue
Block a user