Add agent loading

This commit is contained in:
Marek Wolan
2023-11-16 15:40:49 +00:00
parent e52d1fbd45
commit 0861663cc1
3 changed files with 8 additions and 3 deletions

View File

@@ -119,6 +119,7 @@ def setup(overwrite_existing: bool = True) -> None:
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
@@ -132,4 +133,4 @@ def session(
if not config:
config = example_config_path()
print(config)
run(config_path=config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -1,6 +1,7 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
import enlighten
@@ -297,7 +298,7 @@ class PrimaiteSession:
return NotImplemented
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
@@ -516,6 +517,8 @@ class PrimaiteSession:
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
# READ IO SETTINGS
io_settings = cfg.get("io_settings", {})

View File

@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
@@ -31,7 +32,7 @@ def run(
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()