Fix CLI and Session to work with new classes

This commit is contained in:
Marek Wolan
2023-10-27 14:26:52 +01:00
parent d4eb499729
commit b81c1739f8
6 changed files with 58 additions and 65 deletions

View File

@@ -123,37 +123,19 @@ def session(
"""
Run a PrimAITE session.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
and laydown config.
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from threading import Thread
from primaite.config.load import example_config_path
from primaite.main import run
from primaite.utils.start_gate_server import start_gate_server
else:
# start a new session using tc and ldc
if not tc:
tc = main_training_config_path()
server_thread = Thread(target=start_gate_server)
server_thread.start()
if not ldc:
ldc = dos_very_basic_config_path()
run(
training_config_path=tc,
lay_down_config_path=ldc,
legacy_training_config=legacy_tc,
legacy_lay_down_config=legacy_ldc,
)
if not config:
config = example_config_path()
print(config)
run(config_path=config)

View File

@@ -1,12 +1,14 @@
from pathlib import Path
from typing import Union
from typing import Dict, Final, Union
import yaml
from primaite import getLogger
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER = getLogger(__name__)
_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config"
def load(file_path: Union[str, Path]) -> Dict:
"""
@@ -17,6 +19,27 @@ def load(file_path: Union[str, Path]) -> Dict:
:return: Config dictionary
:rtype: Dict
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if not file_path.exists():
_LOGGER.error(f"File does not exist: {file_path}")
raise FileNotFoundError(f"File does not exist: {file_path}")
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loaded config from {file_path}")
return config
def example_config_path() -> Path:
"""
Get the path to the example config.
:return: Path to the example config.
:rtype: Path
"""
path = _EXAMPLE_CFG / "example_config.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -5,12 +5,6 @@ class PrimaiteError(Exception):
pass
class RLlibAgentError(PrimaiteError):
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
pass
class NetworkError(PrimaiteError):
"""Raised when an error occurs at the network level."""

View File

@@ -5,6 +5,8 @@ from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import load
from primaite.game.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
@@ -12,11 +14,7 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
config_path: Optional[Union[str, Path]] = "",
) -> None:
"""
Run the PrimAITE Session.
@@ -32,28 +30,17 @@ def run(
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
# session = PrimaiteSession(
# training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
# )
# session.setup()
# session.learn()
# session.evaluate()
return NotImplemented
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess.start_session()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
parser.add_argument("--config")
args = parser.parse_args()
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
if not args.config:
_LOGGER.error("Please provide a config file using the --config " "argument")
run(session_path=args.config)

View File

@@ -1,5 +1,12 @@
"""Utility script to start the gate server for running PrimAITE in attached mode."""
from arcd_gate.server.gate_service import GATEService
service = GATEService()
service.start()
def start_gate_server():
"""Start the gate server."""
service = GATEService()
service.start()
if __name__ == "__main__":
start_gate_server()