diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 0f17525e..81ab2792 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -119,6 +119,7 @@ def setup(overwrite_existing: bool = True) -> None: @app.command() def session( config: Optional[str] = None, + agent_load_file: Optional[str] = None, ) -> None: """ Run a PrimAITE session. @@ -132,4 +133,4 @@ def session( if not config: config = example_config_path() print(config) - run(config_path=config) + run(config_path=config, agent_load_path=agent_load_file) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 88c1e061..f265b7d9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,6 +1,7 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" from enum import Enum from ipaddress import IPv4Address +from pathlib import Path from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple import enlighten @@ -297,7 +298,7 @@ class PrimaiteSession: return NotImplemented @classmethod - def from_config(cls, cfg: dict) -> "PrimaiteSession": + def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary. The config dictionary should have the following top-level keys: @@ -516,6 +517,8 @@ class PrimaiteSession: # CREATE POLICY sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) # READ IO SETTINGS io_settings = cfg.get("io_settings", {}) diff --git a/src/primaite/main.py b/src/primaite/main.py index 831419d4..1699fe51 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__) def run( config_path: Optional[Union[str, Path]] = "", + agent_load_path: Optional[Union[str, Path]] = None, ) -> None: """ Run the PrimAITE Session. @@ -31,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteSession.from_config(cfg=cfg) + sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session()