#917 - Began the process of reloading existing agents into the session

This commit is contained in:
Chris McCarthy
2023-06-28 19:54:00 +01:00
parent 1d3778f400
commit 7f912df383
4 changed files with 82 additions and 18 deletions

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
@@ -6,6 +7,8 @@ from pathlib import Path
from typing import Optional, Final, Dict, Union
from uuid import uuid4
import yaml
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.config import lay_down_config
@@ -58,23 +61,34 @@ class AgentSessionABC(ABC):
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = _get_session_path(self.session_timestamp)
"The Session path"
self.learning_path = self.session_path / "learning"
"The learning outputs path"
self.evaluation_path = self.session_path / "evaluation"
"The evaluation outputs path"
self.checkpoints_path = self.learning_path / "checkpoints"
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
"The Session checkpoints path"
self.timestamp_str = self.session_timestamp.strftime(
"%Y-%m-%d_%H-%M-%S")
"The session timestamp as a string"
@property
def timestamp_str(self) -> str:
"""The session timestamp as a string."""
return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
@property
def learning_path(self) -> Path:
"""The learning outputs path."""
return self.session_path / "learning"
@property
def evaluation_path(self) -> Path:
"""The evaluation outputs path."""
return self.session_path / "evaluation"
@property
def checkpoints_path(self) -> Path:
"""The Session checkpoints path."""
return self.learning_path / "checkpoints"
@property
def uuid(self):
@@ -104,8 +118,14 @@ class AgentSessionABC(ABC):
"uuid": self.uuid,
"start_datetime": self.session_timestamp.isoformat(),
"end_datetime": None,
"total_episodes": None,
"total_time_steps": None,
"learning": {
"total_episodes": None,
"total_time_steps": None
},
"evaluation": {
"total_episodes": None,
"total_time_steps": None
},
"env": {
"training_config": self._training_config.to_dict(
json_serializable=True
@@ -134,8 +154,13 @@ class AgentSessionABC(ABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._env.episode_count
metadata_dict["total_time_steps"] = self._env.total_step_count
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -172,6 +197,7 @@ class AgentSessionABC(ABC):
_LOGGER.debug("Writing transactions")
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
@abstractmethod
def evaluate(
@@ -180,6 +206,7 @@ class AgentSessionABC(ABC):
episodes: Optional[int] = None,
**kwargs
):
self.is_eval = True
_LOGGER.info("Finished evaluation")
@abstractmethod
@@ -188,7 +215,40 @@ class AgentSessionABC(ABC):
@classmethod
@abstractmethod
def load(cls):
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
if not isinstance(path, Path):
path = Path(path)
if path.exists():
# Unpack the session_metadata.json file
md_file = path / "session_metadata.json"
with open(md_file, "r") as file:
md_dict = json.load(file)
# Create a temp directory and dump the training and lay down
# configs into it
temp_dir = path / ".temp"
temp_dir.mkdir(exist_ok=True)
temp_tc = temp_dir / "tc.yaml"
with open(temp_tc, "w") as file:
yaml.dump(md_dict["env"]["training_config"], file)
temp_ldc = temp_dir / "ldc.yaml"
with open(temp_ldc, "w") as file:
yaml.dump(md_dict["env"]["lay_down_config"], file)
agent = cls(temp_tc, temp_ldc)
agent.session_path = path
return agent
else:
# Session path does not exist
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@abstractmethod

View File

@@ -44,6 +44,8 @@ class SB3Agent(AgentSessionABC):
f"{self._training_config.agent_identifier}"
)
self.is_eval = False
def _setup(self):
super()._setup()
self._env = Primaite(
@@ -86,6 +88,7 @@ class SB3Agent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps...")
for i in range(episodes):
@@ -108,6 +111,7 @@ class SB3Agent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:
deterministic_str = "deterministic"
else:

View File

@@ -35,7 +35,7 @@ hard_coded_agent_view: FULL
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: ANY
action_type: NODE
# Number of episodes to run per session
num_episodes: 1000

View File

@@ -260,6 +260,8 @@ class Primaite(Env):
learning_session=False
)
self.episode_count = 0
self.step_count = 0
self.total_step_count = 0
def reset(self):
"""
@@ -329,8 +331,6 @@ class Primaite(Env):
# Load the action space into the transaction
transaction.action_space = copy.deepcopy(action)
initial_nodes = copy.deepcopy(self.nodes)
# 1. Implement Blue Action
self.interpret_action_and_apply(action)
# Take snapshots of nodes and links
@@ -383,7 +383,7 @@ class Primaite(Env):
# 5. Calculate reward signal (for RL)
reward = calculate_reward_function(
initial_nodes,
self.nodes_post_pol,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,