Remove more GATE stuff
This commit is contained in:
@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from arcd_gate.cli import setup as gate_setup
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import reset_demo_notebooks, reset_example_configs
|
||||
|
||||
@@ -115,9 +113,6 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Setting up ARCD GATE...")
|
||||
gate_setup()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@@ -131,14 +126,8 @@ def session(
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from threading import Thread
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.main import run
|
||||
from primaite.utils.start_gate_server import start_gate_server
|
||||
|
||||
server_thread = Thread(target=start_gate_server)
|
||||
server_thread.start()
|
||||
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
|
||||
@@ -108,7 +108,7 @@ game_config:
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: GATERLAgent
|
||||
type: idk???
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
|
||||
@@ -15,16 +15,12 @@ class PolicyABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
) -> None:
|
||||
def learn(self, n_episodes: int, n_time_steps: int) -> None:
|
||||
"""Train the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(
|
||||
self,
|
||||
) -> None:
|
||||
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -33,35 +33,24 @@ class SB3Policy(PolicyABC):
|
||||
seed=...,
|
||||
) # TODO: populate values once I figure out how to get them from the config / session
|
||||
|
||||
def learn(
|
||||
self,
|
||||
) -> None:
|
||||
def learn(self, n_episodes: int, n_time_steps: int) -> None:
|
||||
"""Train the agent."""
|
||||
time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session
|
||||
episodes = 10 # TODO: populate values once I figure out how to get them from the config / session
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
|
||||
for i in range(n_episodes):
|
||||
self._agent.learn(total_timesteps=n_time_steps)
|
||||
self._save_checkpoint()
|
||||
pass
|
||||
|
||||
def eval(
|
||||
self,
|
||||
) -> None:
|
||||
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session
|
||||
num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session
|
||||
deterministic = True # TODO: populate values once I figure out how to get them from the config / session
|
||||
|
||||
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
|
||||
for episode in range(num_episodes):
|
||||
for episode in range(n_episodes):
|
||||
obs = self.session.env.reset()
|
||||
for step in range(time_steps):
|
||||
for step in range(n_time_steps):
|
||||
action, _states = self._agent.predict(obs, deterministic=deterministic)
|
||||
obs, rewards, truncated, terminated, info = self.session.env.step(action)
|
||||
|
||||
def save(
|
||||
self,
|
||||
) -> None:
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
savepath = (
|
||||
"temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session
|
||||
@@ -69,22 +58,16 @@ class SB3Policy(PolicyABC):
|
||||
self._agent.save(savepath)
|
||||
pass
|
||||
|
||||
def load(
|
||||
self,
|
||||
) -> None:
|
||||
def load(self) -> None:
|
||||
"""Load agent from a checkpoint."""
|
||||
self._agent_class.load("temp/path/to/save.pth", env=self.session.env)
|
||||
pass
|
||||
|
||||
def close(
|
||||
self,
|
||||
) -> None:
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
self,
|
||||
) -> "SB3Policy":
|
||||
def from_config(self) -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
pass
|
||||
|
||||
@@ -96,15 +96,16 @@ class PrimaiteSession:
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
# n_learn_steps = self.training_options.n_learn_steps
|
||||
n_learn_steps = self.training_options.n_learn_steps
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
# n_eval_steps = self.training_options.n_eval_steps
|
||||
n_eval_steps = self.training_options.n_eval_steps
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
deterministic_eval = True # TODO: get this value from config
|
||||
if n_learn_episodes > 0:
|
||||
self.policy.learn()
|
||||
self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps)
|
||||
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval()
|
||||
self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user