Fix CLI and Session to work with new classes
This commit is contained in:
@@ -123,37 +123,19 @@ def session(
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
|
||||
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
|
||||
load: The directory of a previous session. Optional. If no value is passed, then the session
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
|
||||
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
|
||||
|
||||
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from threading import Thread
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.main import run
|
||||
from primaite.utils.start_gate_server import start_gate_server
|
||||
|
||||
else:
|
||||
# start a new session using tc and ldc
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
server_thread = Thread(target=start_gate_server)
|
||||
server_thread.start()
|
||||
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(
|
||||
training_config_path=tc,
|
||||
lay_down_config_path=ldc,
|
||||
legacy_training_config=legacy_tc,
|
||||
legacy_lay_down_config=legacy_ldc,
|
||||
)
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
print(config)
|
||||
run(config_path=config)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config"
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path]) -> Dict:
|
||||
"""
|
||||
@@ -17,6 +19,27 @@ def load(file_path: Union[str, Path]) -> Dict:
|
||||
:return: Config dictionary
|
||||
:rtype: Dict
|
||||
"""
|
||||
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
_LOGGER.error(f"File does not exist: {file_path}")
|
||||
raise FileNotFoundError(f"File does not exist: {file_path}")
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loaded config from {file_path}")
|
||||
return config
|
||||
|
||||
|
||||
def example_config_path() -> Path:
|
||||
"""
|
||||
Get the path to the example config.
|
||||
|
||||
:return: Path to the example config.
|
||||
:rtype: Path
|
||||
"""
|
||||
path = _EXAMPLE_CFG / "example_config.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
|
||||
@@ -5,12 +5,6 @@ class PrimaiteError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RLlibAgentError(PrimaiteError):
|
||||
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NetworkError(PrimaiteError):
|
||||
"""Raised when an error occurs at the network level."""
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import load
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
@@ -12,11 +14,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
@@ -32,28 +30,17 @@ def run(
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
# session = PrimaiteSession(
|
||||
# training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
# )
|
||||
|
||||
# session.setup()
|
||||
# session.learn()
|
||||
# session.evaluate()
|
||||
return NotImplemented
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
parser.add_argument("--load")
|
||||
parser.add_argument("--config")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.load:
|
||||
run(session_path=args.load)
|
||||
else:
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
if not args.config:
|
||||
_LOGGER.error("Please provide a config file using the --config " "argument")
|
||||
|
||||
run(session_path=args.config)
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
service = GATEService()
|
||||
service.start()
|
||||
|
||||
def start_gate_server():
|
||||
"""Start the gate server."""
|
||||
service = GATEService()
|
||||
service.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_gate_server()
|
||||
|
||||
Reference in New Issue
Block a user