diff --git a/src/primaite/cli.py b/src/primaite/cli.py index ea0247ec..43b97022 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -123,37 +123,19 @@ def session( """ Run a PrimAITE session. - tc: The training config filepath. Optional. If no value is passed then - example default training config is used from: - ~/primaite/2.0.0/config/example_config/training/training_config_main.yaml. - - ldc: The lay down config file path. Optional. If no value is passed then - example default lay down config is used from: - ~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. - - load: The directory of a previous session. Optional. If no value is passed, then the session - will use the default training config and laydown config. Inversely, if a training config and laydown config - is passed while a session directory is passed, PrimAITE will load the session and ignore the training config - and laydown config. - - legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0. - - legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0. + :param config: The path to the config file. Optional, if None, the example config will be used. + :type config: Optional[str] """ - from primaite.config.lay_down_config import dos_very_basic_config_path + from threading import Thread + + from primaite.config.load import example_config_path from primaite.main import run + from primaite.utils.start_gate_server import start_gate_server - else: - # start a new session using tc and ldc - if not tc: - tc = main_training_config_path() + server_thread = Thread(target=start_gate_server) + server_thread.start() - if not ldc: - ldc = dos_very_basic_config_path() - - run( - training_config_path=tc, - lay_down_config_path=ldc, - legacy_training_config=legacy_tc, - legacy_lay_down_config=legacy_ldc, - ) + if not config: + config = example_config_path() + print(config) + run(config_path=config) diff --git a/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml similarity index 100% rename from example_config.yaml rename to src/primaite/config/_package_data/example_config.yaml diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index 77b76299..b01eb129 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -1,12 +1,14 @@ from pathlib import Path -from typing import Union +from typing import Dict, Final, Union import yaml -from primaite import getLogger +from primaite import getLogger, PRIMAITE_PATHS _LOGGER = getLogger(__name__) +_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" + def load(file_path: Union[str, Path]) -> Dict: """ @@ -17,6 +19,27 @@ def load(file_path: Union[str, Path]) -> Dict: :return: Config dictionary :rtype: Dict """ - if not isinstance(file_path, Path): file_path = Path(file_path) + if not file_path.exists(): + _LOGGER.error(f"File does not exist: {file_path}") + raise FileNotFoundError(f"File does not exist: {file_path}") + with open(file_path, "r") as file: + config = yaml.safe_load(file) + _LOGGER.debug(f"Loaded config from {file_path}") + return config + + +def example_config_path() -> Path: + """ + Get the path to the example config. + + :return: Path to the example config. + :rtype: Path + """ + path = _EXAMPLE_CFG / "example_config.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path diff --git a/src/primaite/exceptions.py b/src/primaite/exceptions.py index 025f6d41..6aa140ba 100644 --- a/src/primaite/exceptions.py +++ b/src/primaite/exceptions.py @@ -5,12 +5,6 @@ class PrimaiteError(Exception): pass -class RLlibAgentError(PrimaiteError): - """Raised when there is a generic error with a RLlib agent that is specific to PRimAITE.""" - - pass - - class NetworkError(PrimaiteError): """Raised when an error occurs at the network level.""" diff --git a/src/primaite/main.py b/src/primaite/main.py index 0cbcff0e..831419d4 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Optional, Union from primaite import getLogger +from primaite.config.load import load +from primaite.game.session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession @@ -12,11 +14,7 @@ _LOGGER = getLogger(__name__) def run( - training_config_path: Optional[Union[str, Path]] = "", - lay_down_config_path: Optional[Union[str, Path]] = "", - session_path: Optional[Union[str, Path]] = None, - legacy_training_config: bool = False, - legacy_lay_down_config: bool = False, + config_path: Optional[Union[str, Path]] = "", ) -> None: """ Run the PrimAITE Session. @@ -32,28 +30,17 @@ def run( :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0, otherwise False. """ - # session = PrimaiteSession( - # training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config - # ) - - # session.setup() - # session.learn() - # session.evaluate() - return NotImplemented + cfg = load(config_path) + sess = PrimaiteSession.from_config(cfg=cfg) + sess.start_session() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--tc") - parser.add_argument("--ldc") - parser.add_argument("--load") + parser.add_argument("--config") args = parser.parse_args() - if args.load: - run(session_path=args.load) - else: - if not args.tc: - _LOGGER.error("Please provide a training config file using the --tc " "argument") - if not args.ldc: - _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") - run(training_config_path=args.tc, lay_down_config_path=args.ldc) + if not args.config: + _LOGGER.error("Please provide a config file using the --config " "argument") + + run(session_path=args.config) diff --git a/src/primaite/utils/start_gate_server.py b/src/primaite/utils/start_gate_server.py index 53508cd2..d91952f2 100644 --- a/src/primaite/utils/start_gate_server.py +++ b/src/primaite/utils/start_gate_server.py @@ -1,5 +1,12 @@ """Utility script to start the gate server for running PrimAITE in attached mode.""" from arcd_gate.server.gate_service import GATEService -service = GATEService() -service.start() + +def start_gate_server(): + """Start the gate server.""" + service = GATEService() + service.start() + + +if __name__ == "__main__": + start_gate_server()