diff --git a/src/primaite/cli.py b/src/primaite/cli.py index a5b3be46..0f17525e 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ - from arcd_gate.cli import setup as gate_setup - from primaite import getLogger from primaite.setup import reset_demo_notebooks, reset_example_configs @@ -115,9 +113,6 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Rebuilding the example notebooks...") reset_example_configs.run(overwrite_existing=True) - _LOGGER.info("Setting up ARCD GATE...") - gate_setup() - _LOGGER.info("PrimAITE setup complete!") @@ -131,14 +126,8 @@ def session( :param config: The path to the config file. Optional, if None, the example config will be used. :type config: Optional[str] """ - from threading import Thread - from primaite.config.load import example_config_path from primaite.main import run - from primaite.utils.start_gate_server import start_gate_server - - server_thread = Thread(target=start_gate_server) - server_thread.start() if not config: config = example_config_path() diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index ee42cf4f..676028bb 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -108,7 +108,7 @@ game_config: - ref: defender team: BLUE - type: GATERLAgent + type: idk??? observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 8d5a9a08..404d6f31 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -15,16 +15,12 @@ class PolicyABC(ABC): pass @abstractmethod - def learn( - self, - ) -> None: + def learn(self, n_episodes: int, n_time_steps: int) -> None: """Train the agent.""" pass @abstractmethod - def eval( - self, - ) -> None: + def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: """Evaluate the agent.""" pass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 151e860d..2d9da1db 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -33,35 +33,24 @@ class SB3Policy(PolicyABC): seed=..., ) # TODO: populate values once I figure out how to get them from the config / session - def learn( - self, - ) -> None: + def learn(self, n_episodes: int, n_time_steps: int) -> None: """Train the agent.""" - time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session - episodes = 10 # TODO: populate values once I figure out how to get them from the config / session - for i in range(episodes): - self._agent.learn(total_timesteps=time_steps) + # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB + for i in range(n_episodes): + self._agent.learn(total_timesteps=n_time_steps) self._save_checkpoint() pass - def eval( - self, - ) -> None: + def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: """Evaluate the agent.""" - time_steps = 9999 # TODO: populate values once I figure out how to get them from the config / session - num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session - deterministic = True # TODO: populate values once I figure out how to get them from the config / session - # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB - for episode in range(num_episodes): + for episode in range(n_episodes): obs = self.session.env.reset() - for step in range(time_steps): + for step in range(n_time_steps): action, _states = self._agent.predict(obs, deterministic=deterministic) obs, rewards, truncated, terminated, info = self.session.env.step(action) - def save( - self, - ) -> None: + def save(self) -> None: """Save the agent.""" savepath = ( "temp/path/to/save.pth" # TODO: populate values once I figure out how to get them from the config / session @@ -69,22 +58,16 @@ class SB3Policy(PolicyABC): self._agent.save(savepath) pass - def load( - self, - ) -> None: + def load(self) -> None: """Load agent from a checkpoint.""" self._agent_class.load("temp/path/to/save.pth", env=self.session.env) pass - def close( - self, - ) -> None: + def close(self) -> None: """Close the agent.""" pass @classmethod - def from_config( - self, - ) -> "SB3Policy": + def from_config(self) -> "SB3Policy": """Create an agent from config file.""" pass diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a088d05e..9d241932 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -96,15 +96,16 @@ class PrimaiteSession: def start_session(self) -> None: """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - # n_learn_steps = self.training_options.n_learn_steps + n_learn_steps = self.training_options.n_learn_steps n_learn_episodes = self.training_options.n_learn_episodes - # n_eval_steps = self.training_options.n_eval_steps + n_eval_steps = self.training_options.n_eval_steps n_eval_episodes = self.training_options.n_eval_episodes + deterministic_eval = True # TODO: get this value from config if n_learn_episodes > 0: - self.policy.learn() + self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps) if n_eval_episodes > 0: - self.policy.eval() + self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval) def step(self): """