Merged PR 199: Fix setup and session commands
## Summary Make PrimAITE work with the new: - configs - CLI commands ## Test process I ran primaite setup and primaite session to check that they work ## Checklist - [x] PR is linked to a **work item** - [x] **acceptance criteria** of linked ticket are met - [x] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [x] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #1987, #1998
This commit is contained in:
22
sandbox.py
22
sandbox.py
@@ -1,22 +0,0 @@
|
||||
# flake8: noqa
|
||||
import logging
|
||||
|
||||
from primaite import _PRIMAITE_CONFIG, PRIMAITE_PATHS
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
_PRIMAITE_CONFIG["log_level"] = logging.DEBUG
|
||||
print(PRIMAITE_PATHS.app_log_dir_path)
|
||||
import itertools
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
with open("example_config.yaml", "r") as file:
|
||||
cfg = yaml.safe_load(file)
|
||||
sess = PrimaiteSession.from_config(cfg)
|
||||
|
||||
sess.start_session()
|
||||
@@ -10,7 +10,6 @@ import yaml
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.data_viz import PlotlyTemplate
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) ->
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
start_jupyter_session()
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
@@ -97,14 +88,6 @@ def version() -> None:
|
||||
print(primaite.__version__)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
old_installation_clean_up.run()
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
@@ -112,8 +95,10 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from arcd_gate.cli import setup as gate_setup
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
|
||||
from primaite.setup import reset_demo_notebooks, reset_example_configs
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -130,84 +115,32 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
|
||||
old_installation_clean_up.run()
|
||||
_LOGGER.info("Setting up ARCD GATE...")
|
||||
gate_setup()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
tc: Optional[str] = None,
|
||||
ldc: Optional[str] = None,
|
||||
load: Optional[str] = None,
|
||||
legacy_tc: bool = False,
|
||||
legacy_ldc: bool = False,
|
||||
config: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
|
||||
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
|
||||
load: The directory of a previous session. Optional. If no value is passed, then the session
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
|
||||
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
|
||||
|
||||
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from threading import Thread
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.main import run
|
||||
from primaite.utils.start_gate_server import start_gate_server
|
||||
|
||||
if load is not None:
|
||||
# run a loaded session
|
||||
run(session_path=load)
|
||||
server_thread = Thread(target=start_gate_server)
|
||||
server_thread.start()
|
||||
|
||||
else:
|
||||
# start a new session using tc and ldc
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(
|
||||
training_config_path=tc,
|
||||
lay_down_config_path=ldc,
|
||||
legacy_training_config=legacy_tc,
|
||||
legacy_lay_down_config=legacy_ldc,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
To View, simply call: primaite plotly-template
|
||||
|
||||
To set, call: primaite plotly-template <desired template>
|
||||
|
||||
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if template:
|
||||
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE plotly template: {template.value}")
|
||||
else:
|
||||
template = primaite_config["session"]["outputs"]["plots"]["template"]
|
||||
print(f"PrimAITE plotly template: {template}")
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
print(config)
|
||||
run(config_path=config)
|
||||
|
||||
2
src/primaite/config/__init__.py
Normal file
2
src/primaite/config/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Configuration parameters for running experiments."""
|
||||
45
src/primaite/config/load.py
Normal file
45
src/primaite/config/load.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config"
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path]) -> Dict:
|
||||
"""
|
||||
Read a YAML file and return the contents as a dictionary.
|
||||
|
||||
:param file_path: Path to the YAML file.
|
||||
:type file_path: Union[str, Path]
|
||||
:return: Config dictionary
|
||||
:rtype: Dict
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
_LOGGER.error(f"File does not exist: {file_path}")
|
||||
raise FileNotFoundError(f"File does not exist: {file_path}")
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loaded config from {file_path}")
|
||||
return config
|
||||
|
||||
|
||||
def example_config_path() -> Path:
|
||||
"""
|
||||
Get the path to the example config.
|
||||
|
||||
:return: Path to the example config.
|
||||
:rtype: Path
|
||||
"""
|
||||
path = _EXAMPLE_CFG / "example_config.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
@@ -5,12 +5,6 @@ class PrimaiteError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RLlibAgentError(PrimaiteError):
|
||||
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NetworkError(PrimaiteError):
|
||||
"""Raised when an error occurs at the network level."""
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import load
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
@@ -12,11 +14,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
@@ -32,28 +30,17 @@ def run(
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
# session = PrimaiteSession(
|
||||
# training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
# )
|
||||
|
||||
# session.setup()
|
||||
# session.learn()
|
||||
# session.evaluate()
|
||||
return NotImplemented
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
parser.add_argument("--load")
|
||||
parser.add_argument("--config")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.load:
|
||||
run(session_path=args.load)
|
||||
else:
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
if not args.config:
|
||||
_LOGGER.error("Please provide a config file using the --config " "argument")
|
||||
|
||||
run(session_path=args.config)
|
||||
|
||||
0
src/primaite/notebooks/.gitkeep
Normal file
0
src/primaite/notebooks/.gitkeep
Normal file
@@ -1,107 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.network.networks import arcd_uc2_network\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = arcd_uc2_network()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### set up some services to test if actions are working"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_serv = net.get_node_by_hostname('database_server')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.system.services.database_service import DatabaseService"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_svc = DatabaseService(file_system=db_serv.file_system)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_serv.install_service(db_svc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_serv.describe_state()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,5 +1,12 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
service = GATEService()
|
||||
service.start()
|
||||
|
||||
def start_gate_server():
|
||||
"""Start the gate server."""
|
||||
service = GATEService()
|
||||
service.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_gate_server()
|
||||
|
||||
Reference in New Issue
Block a user