Merged PR 199: Fix setup and session commands

## Summary
Make PrimAITE work with the new:
- configs
- CLI commands

## Test process
I ran primaite setup and primaite session to check that they work

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #1987, #1998
This commit is contained in:
Marek Wolan
2023-10-27 15:14:06 +00:00
10 changed files with 85 additions and 246 deletions

View File

@@ -1,22 +0,0 @@
# flake8: noqa
import logging
from primaite import _PRIMAITE_CONFIG, PRIMAITE_PATHS
from primaite.game.session import PrimaiteSession
_PRIMAITE_CONFIG["log_level"] = logging.DEBUG
print(PRIMAITE_PATHS.app_log_dir_path)
import itertools
import yaml
from primaite.game.agent.interface import AbstractAgent
from primaite.game.session import PrimaiteSession
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.sim_container import Simulation
with open("example_config.yaml", "r") as file:
cfg = yaml.safe_load(file)
sess = PrimaiteSession.from_config(cfg)
sess.start_session()

View File

@@ -10,7 +10,6 @@ import yaml
from typing_extensions import Annotated
from primaite import PRIMAITE_PATHS
from primaite.data_viz import PlotlyTemplate
app = typer.Typer()
@@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) ->
print(f"PrimAITE Log Level: {level}")
@app.command()
def notebooks() -> None:
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
start_jupyter_session()
@app.command()
def version() -> None:
"""Get the installed PrimAITE version number."""
@@ -97,14 +88,6 @@ def version() -> None:
print(primaite.__version__)
@app.command()
def clean_up() -> None:
"""Cleans up left over files from previous version installations."""
from primaite.setup import old_installation_clean_up
old_installation_clean_up.run()
@app.command()
def setup(overwrite_existing: bool = True) -> None:
"""
@@ -112,8 +95,10 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
from arcd_gate.cli import setup as gate_setup
from primaite import getLogger
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
from primaite.setup import reset_demo_notebooks, reset_example_configs
_LOGGER = getLogger(__name__)
@@ -130,84 +115,32 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
old_installation_clean_up.run()
_LOGGER.info("Setting up ARCD GATE...")
gate_setup()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
tc: Optional[str] = None,
ldc: Optional[str] = None,
load: Optional[str] = None,
legacy_tc: bool = False,
legacy_ldc: bool = False,
config: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
and laydown config.
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from threading import Thread
from primaite.config.load import example_config_path
from primaite.main import run
from primaite.utils.start_gate_server import start_gate_server
if load is not None:
# run a loaded session
run(session_path=load)
server_thread = Thread(target=start_gate_server)
server_thread.start()
else:
# start a new session using tc and ldc
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = dos_very_basic_config_path()
run(
training_config_path=tc,
lay_down_config_path=ldc,
legacy_training_config=legacy_tc,
legacy_lay_down_config=legacy_ldc,
)
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
"""
View or set the plotly template for Session plots.
To View, simply call: primaite plotly-template
To set, call: primaite plotly-template <desired template>
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
"""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if template:
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE plotly template: {template.value}")
else:
template = primaite_config["session"]["outputs"]["plots"]["template"]
print(f"PrimAITE plotly template: {template}")
if not config:
config = example_config_path()
print(config)
run(config_path=config)

View File

@@ -0,0 +1,2 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Configuration parameters for running experiments."""

View File

@@ -0,0 +1,45 @@
from pathlib import Path
from typing import Dict, Final, Union
import yaml
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER = getLogger(__name__)
_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config"
def load(file_path: Union[str, Path]) -> Dict:
"""
Read a YAML file and return the contents as a dictionary.
:param file_path: Path to the YAML file.
:type file_path: Union[str, Path]
:return: Config dictionary
:rtype: Dict
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if not file_path.exists():
_LOGGER.error(f"File does not exist: {file_path}")
raise FileNotFoundError(f"File does not exist: {file_path}")
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loaded config from {file_path}")
return config
def example_config_path() -> Path:
"""
Get the path to the example config.
:return: Path to the example config.
:rtype: Path
"""
path = _EXAMPLE_CFG / "example_config.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -5,12 +5,6 @@ class PrimaiteError(Exception):
pass
class RLlibAgentError(PrimaiteError):
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
pass
class NetworkError(PrimaiteError):
"""Raised when an error occurs at the network level."""

View File

@@ -5,6 +5,8 @@ from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import load
from primaite.game.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
@@ -12,11 +14,7 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
config_path: Optional[Union[str, Path]] = "",
) -> None:
"""
Run the PrimAITE Session.
@@ -32,28 +30,17 @@ def run(
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
# session = PrimaiteSession(
# training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
# )
# session.setup()
# session.learn()
# session.evaluate()
return NotImplemented
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess.start_session()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
parser.add_argument("--config")
args = parser.parse_args()
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
if not args.config:
_LOGGER.error("Please provide a config file using the --config " "argument")
run(session_path=args.config)

View File

View File

@@ -1,107 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.networks import arcd_uc2_network\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = arcd_uc2_network()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### set up some services to test if actions are working"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv = net.get_node_by_hostname('database_server')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.system.services.database_service import DatabaseService"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_svc = DatabaseService(file_system=db_serv.file_system)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv.install_service(db_svc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv.describe_state()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,5 +1,12 @@
"""Utility script to start the gate server for running PrimAITE in attached mode."""
from arcd_gate.server.gate_service import GATEService
service = GATEService()
service.start()
def start_gate_server():
"""Start the gate server."""
service = GATEService()
service.start()
if __name__ == "__main__":
start_gate_server()