Remove more GATE stuff

This commit is contained in:
Marek Wolan
2023-11-13 17:12:50 +00:00
parent 08e88e52b0
commit 1cb54da2dd
5 changed files with 19 additions and 50 deletions

View File

@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
from arcd_gate.cli import setup as gate_setup
from primaite import getLogger
from primaite.setup import reset_demo_notebooks, reset_example_configs
@@ -115,9 +113,6 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Setting up ARCD GATE...")
gate_setup()
_LOGGER.info("PrimAITE setup complete!")
@@ -131,14 +126,8 @@ def session(
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from threading import Thread
from primaite.config.load import example_config_path
from primaite.main import run
from primaite.utils.start_gate_server import start_gate_server
server_thread = Thread(target=start_gate_server)
server_thread.start()
if not config:
config = example_config_path()

View File

@@ -108,7 +108,7 @@ game_config:
- ref: defender
team: BLUE
type: GATERLAgent
type: idk???
observation_space:
type: UC2BlueObservation

View File

@@ -15,16 +15,12 @@ class PolicyABC(ABC):
pass
@abstractmethod
def learn(
self,
) -> None:
def learn(self, n_episodes: int, n_time_steps: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(
self,
) -> None:
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass

View File

@@ -33,35 +33,24 @@ class SB3Policy(PolicyABC):
seed=...,
) # TODO: populate values once I figure out how to get them from the config / session
def learn(
self,
) -> None:
def learn(self, n_episodes: int, n_time_steps: int) -> None:
"""Train the agent."""
time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session
episodes = 10 # TODO: populate values once I figure out how to get them from the config / session
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for i in range(n_episodes):
self._agent.learn(total_timesteps=n_time_steps)
self._save_checkpoint()
pass
def eval(
self,
) -> None:
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
"""Evaluate the agent."""
time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session
num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session
deterministic = True # TODO: populate values once I figure out how to get them from the config / session
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for episode in range(num_episodes):
for episode in range(n_episodes):
obs = self.session.env.reset()
for step in range(time_steps):
for step in range(n_time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
obs, rewards, truncated, terminated, info = self.session.env.step(action)
def save(
self,
) -> None:
def save(self) -> None:
"""Save the agent."""
savepath = (
"temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session
@@ -69,22 +58,16 @@ class SB3Policy(PolicyABC):
self._agent.save(savepath)
pass
def load(
self,
) -> None:
def load(self) -> None:
"""Load agent from a checkpoint."""
self._agent_class.load("temp/path/to/save.pth", env=self.session.env)
pass
def close(
self,
) -> None:
def close(self) -> None:
"""Close the agent."""
pass
@classmethod
def from_config(
self,
) -> "SB3Policy":
def from_config(self) -> "SB3Policy":
"""Create an agent from config file."""
pass

View File

@@ -96,15 +96,16 @@ class PrimaiteSession:
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
# n_learn_steps = self.training_options.n_learn_steps
n_learn_steps = self.training_options.n_learn_steps
n_learn_episodes = self.training_options.n_learn_episodes
# n_eval_steps = self.training_options.n_eval_steps
n_eval_steps = self.training_options.n_eval_steps
n_eval_episodes = self.training_options.n_eval_episodes
deterministic_eval = True # TODO: get this value from config
if n_learn_episodes > 0:
self.policy.learn()
self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps)
if n_eval_episodes > 0:
self.policy.eval()
self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval)
def step(self):
"""