Add agent loading
This commit is contained in:
@@ -119,6 +119,7 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
@app.command()
|
||||
def session(
|
||||
config: Optional[str] = None,
|
||||
agent_load_file: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
@@ -132,4 +133,4 @@ def session(
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
print(config)
|
||||
run(config_path=config)
|
||||
run(config_path=config, agent_load_path=agent_load_file)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
|
||||
import enlighten
|
||||
@@ -297,7 +298,7 @@ class PrimaiteSession:
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
@@ -516,6 +517,8 @@ class PrimaiteSession:
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
# READ IO SETTINGS
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
|
||||
@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
def run(
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
agent_load_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
@@ -31,7 +32,7 @@ def run(
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user