diff --git a/sandbox.py b/sandbox.py deleted file mode 100644 index ab5e701f..00000000 --- a/sandbox.py +++ /dev/null @@ -1,22 +0,0 @@ -# flake8: noqa -import logging - -from primaite import _PRIMAITE_CONFIG, PRIMAITE_PATHS -from primaite.game.session import PrimaiteSession - -_PRIMAITE_CONFIG["log_level"] = logging.DEBUG -print(PRIMAITE_PATHS.app_log_dir_path) -import itertools - -import yaml - -from primaite.game.agent.interface import AbstractAgent -from primaite.game.session import PrimaiteSession -from primaite.simulator.network.networks import arcd_uc2_network -from primaite.simulator.sim_container import Simulation - -with open("example_config.yaml", "r") as file: - cfg = yaml.safe_load(file) -sess = PrimaiteSession.from_config(cfg) - -sess.start_session() diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 9bdc414d..a5b3be46 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -10,7 +10,6 @@ import yaml from typing_extensions import Annotated from primaite import PRIMAITE_PATHS -from primaite.data_viz import PlotlyTemplate app = typer.Typer() @@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> print(f"PrimAITE Log Level: {level}") -@app.command() -def notebooks() -> None: - """Start Jupyter Lab in the users PrimAITE notebooks directory.""" - from primaite.notebooks import start_jupyter_session - - start_jupyter_session() - - @app.command() def version() -> None: """Get the installed PrimAITE version number.""" @@ -97,14 +88,6 @@ def version() -> None: print(primaite.__version__) -@app.command() -def clean_up() -> None: - """Cleans up left over files from previous version installations.""" - from primaite.setup import old_installation_clean_up - - old_installation_clean_up.run() - - @app.command() def setup(overwrite_existing: bool = True) -> None: """ @@ -112,8 +95,10 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ + from arcd_gate.cli import setup as gate_setup + from primaite import getLogger - from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs + from primaite.setup import reset_demo_notebooks, reset_example_configs _LOGGER = getLogger(__name__) @@ -130,84 +115,32 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Rebuilding the example notebooks...") reset_example_configs.run(overwrite_existing=True) - _LOGGER.info("Performing a clean-up of previous PrimAITE installations...") - old_installation_clean_up.run() + _LOGGER.info("Setting up ARCD GATE...") + gate_setup() _LOGGER.info("PrimAITE setup complete!") @app.command() def session( - tc: Optional[str] = None, - ldc: Optional[str] = None, - load: Optional[str] = None, - legacy_tc: bool = False, - legacy_ldc: bool = False, + config: Optional[str] = None, ) -> None: """ Run a PrimAITE session. - tc: The training config filepath. Optional. If no value is passed then - example default training config is used from: - ~/primaite/2.0.0/config/example_config/training/training_config_main.yaml. - - ldc: The lay down config file path. Optional. If no value is passed then - example default lay down config is used from: - ~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. - - load: The directory of a previous session. Optional. If no value is passed, then the session - will use the default training config and laydown config. Inversely, if a training config and laydown config - is passed while a session directory is passed, PrimAITE will load the session and ignore the training config - and laydown config. - - legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0. - - legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0. + :param config: The path to the config file. Optional, if None, the example config will be used. + :type config: Optional[str] """ - from primaite.config.lay_down_config import dos_very_basic_config_path - from primaite.config.training_config import main_training_config_path + from threading import Thread + + from primaite.config.load import example_config_path from primaite.main import run + from primaite.utils.start_gate_server import start_gate_server - if load is not None: - # run a loaded session - run(session_path=load) + server_thread = Thread(target=start_gate_server) + server_thread.start() - else: - # start a new session using tc and ldc - if not tc: - tc = main_training_config_path() - - if not ldc: - ldc = dos_very_basic_config_path() - - run( - training_config_path=tc, - lay_down_config_path=ldc, - legacy_training_config=legacy_tc, - legacy_lay_down_config=legacy_ldc, - ) - - -@app.command() -def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None: - """ - View or set the plotly template for Session plots. - - To View, simply call: primaite plotly-template - - To set, call: primaite plotly-template - - For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK - """ - if PRIMAITE_PATHS.app_config_file_path.exists(): - with open(PRIMAITE_PATHS.app_config_file_path, "r") as file: - primaite_config = yaml.safe_load(file) - - if template: - primaite_config["session"]["outputs"]["plots"]["template"] = template.value - with open(PRIMAITE_PATHS.app_config_file_path, "w") as file: - yaml.dump(primaite_config, file) - print(f"PrimAITE plotly template: {template.value}") - else: - template = primaite_config["session"]["outputs"]["plots"]["template"] - print(f"PrimAITE plotly template: {template}") + if not config: + config = example_config_path() + print(config) + run(config_path=config) diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py new file mode 100644 index 00000000..92f5a7d2 --- /dev/null +++ b/src/primaite/config/__init__.py @@ -0,0 +1,2 @@ +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +"""Configuration parameters for running experiments.""" diff --git a/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml similarity index 100% rename from example_config.yaml rename to src/primaite/config/_package_data/example_config.yaml diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py new file mode 100644 index 00000000..b01eb129 --- /dev/null +++ b/src/primaite/config/load.py @@ -0,0 +1,45 @@ +from pathlib import Path +from typing import Dict, Final, Union + +import yaml + +from primaite import getLogger, PRIMAITE_PATHS + +_LOGGER = getLogger(__name__) + +_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" + + +def load(file_path: Union[str, Path]) -> Dict: + """ + Read a YAML file and return the contents as a dictionary. + + :param file_path: Path to the YAML file. + :type file_path: Union[str, Path] + :return: Config dictionary + :rtype: Dict + """ + if not isinstance(file_path, Path): + file_path = Path(file_path) + if not file_path.exists(): + _LOGGER.error(f"File does not exist: {file_path}") + raise FileNotFoundError(f"File does not exist: {file_path}") + with open(file_path, "r") as file: + config = yaml.safe_load(file) + _LOGGER.debug(f"Loaded config from {file_path}") + return config + + +def example_config_path() -> Path: + """ + Get the path to the example config. + + :return: Path to the example config. + :rtype: Path + """ + path = _EXAMPLE_CFG / "example_config.yaml" + if not path.exists(): + msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" + _LOGGER.error(msg) + raise FileNotFoundError(msg) + return path diff --git a/src/primaite/exceptions.py b/src/primaite/exceptions.py index 025f6d41..6aa140ba 100644 --- a/src/primaite/exceptions.py +++ b/src/primaite/exceptions.py @@ -5,12 +5,6 @@ class PrimaiteError(Exception): pass -class RLlibAgentError(PrimaiteError): - """Raised when there is a generic error with a RLlib agent that is specific to PRimAITE.""" - - pass - - class NetworkError(PrimaiteError): """Raised when an error occurs at the network level.""" diff --git a/src/primaite/main.py b/src/primaite/main.py index 0cbcff0e..831419d4 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Optional, Union from primaite import getLogger +from primaite.config.load import load +from primaite.game.session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession @@ -12,11 +14,7 @@ _LOGGER = getLogger(__name__) def run( - training_config_path: Optional[Union[str, Path]] = "", - lay_down_config_path: Optional[Union[str, Path]] = "", - session_path: Optional[Union[str, Path]] = None, - legacy_training_config: bool = False, - legacy_lay_down_config: bool = False, + config_path: Optional[Union[str, Path]] = "", ) -> None: """ Run the PrimAITE Session. @@ -32,28 +30,17 @@ def run( :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0, otherwise False. """ - # session = PrimaiteSession( - # training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config - # ) - - # session.setup() - # session.learn() - # session.evaluate() - return NotImplemented + cfg = load(config_path) + sess = PrimaiteSession.from_config(cfg=cfg) + sess.start_session() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--tc") - parser.add_argument("--ldc") - parser.add_argument("--load") + parser.add_argument("--config") args = parser.parse_args() - if args.load: - run(session_path=args.load) - else: - if not args.tc: - _LOGGER.error("Please provide a training config file using the --tc " "argument") - if not args.ldc: - _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") - run(training_config_path=args.tc, lay_down_config_path=args.ldc) + if not args.config: + _LOGGER.error("Please provide a config file using the --config " "argument") + + run(session_path=args.config) diff --git a/src/primaite/notebooks/.gitkeep b/src/primaite/notebooks/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/notebooks/scratch.ipynb b/src/primaite/notebooks/scratch.ipynb deleted file mode 100644 index 4e873460..00000000 --- a/src/primaite/notebooks/scratch.ipynb +++ /dev/null @@ -1,107 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from primaite.simulator.network.networks import arcd_uc2_network\n", - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "net = arcd_uc2_network()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### set up some services to test if actions are working" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db_serv = net.get_node_by_hostname('database_server')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from primaite.simulator.system.services.database_service import DatabaseService" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db_svc = DatabaseService(file_system=db_serv.file_system)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db_serv.install_service(db_svc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "db_serv.describe_state()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/src/primaite/utils/start_gate_server.py b/src/primaite/utils/start_gate_server.py index 53508cd2..d91952f2 100644 --- a/src/primaite/utils/start_gate_server.py +++ b/src/primaite/utils/start_gate_server.py @@ -1,5 +1,12 @@ """Utility script to start the gate server for running PrimAITE in attached mode.""" from arcd_gate.server.gate_service import GATEService -service = GATEService() -service.start() + +def start_gate_server(): + """Start the gate server.""" + service = GATEService() + service.start() + + +if __name__ == "__main__": + start_gate_server()