#917 - Began the process of reloading existing agents into the session
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -6,6 +7,8 @@ from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.config import lay_down_config
|
||||
@@ -58,23 +61,34 @@ class AgentSessionABC(ABC):
|
||||
self._agent = None
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
|
||||
self._uuid = str(uuid4())
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
"The session timestamp"
|
||||
self.session_path = _get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
self.learning_path = self.session_path / "learning"
|
||||
"The learning outputs path"
|
||||
self.evaluation_path = self.session_path / "evaluation"
|
||||
"The evaluation outputs path"
|
||||
self.checkpoints_path = self.learning_path / "checkpoints"
|
||||
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
|
||||
"The Session checkpoints path"
|
||||
|
||||
self.timestamp_str = self.session_timestamp.strftime(
|
||||
"%Y-%m-%d_%H-%M-%S")
|
||||
"The session timestamp as a string"
|
||||
@property
def timestamp_str(self) -> str:
    """
    The session timestamp as a string.

    :return: ``self.session_timestamp`` formatted with the pattern
        ``%Y-%m-%d_%H-%M-%S``.
    """
    return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
@property
def learning_path(self) -> Path:
    """
    The learning outputs path.

    :return: The ``learning`` entry directly beneath
        ``self.session_path``. Computed on every access, so it follows
        any later reassignment of ``session_path``.
    """
    return self.session_path / "learning"
|
||||
|
||||
@property
def evaluation_path(self) -> Path:
    """
    The evaluation outputs path.

    :return: The ``evaluation`` entry directly beneath
        ``self.session_path``. Computed on every access, so it follows
        any later reassignment of ``session_path``.
    """
    return self.session_path / "evaluation"
|
||||
|
||||
@property
def checkpoints_path(self) -> Path:
    """
    The Session checkpoints path.

    :return: The ``checkpoints`` entry directly beneath
        ``self.learning_path``. This property only builds the path; it
        does not create the directory.
    """
    return self.learning_path / "checkpoints"
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
@@ -104,8 +118,14 @@ class AgentSessionABC(ABC):
|
||||
"uuid": self.uuid,
|
||||
"start_datetime": self.session_timestamp.isoformat(),
|
||||
"end_datetime": None,
|
||||
"total_episodes": None,
|
||||
"total_time_steps": None,
|
||||
"learning": {
|
||||
"total_episodes": None,
|
||||
"total_time_steps": None
|
||||
},
|
||||
"evaluation": {
|
||||
"total_episodes": None,
|
||||
"total_time_steps": None
|
||||
},
|
||||
"env": {
|
||||
"training_config": self._training_config.to_dict(
|
||||
json_serializable=True
|
||||
@@ -134,8 +154,13 @@ class AgentSessionABC(ABC):
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._env.episode_count
|
||||
metadata_dict["total_time_steps"] = self._env.total_step_count
|
||||
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -172,6 +197,7 @@ class AgentSessionABC(ABC):
|
||||
_LOGGER.debug("Writing transactions")
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
self.is_eval = False
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
@@ -180,6 +206,7 @@ class AgentSessionABC(ABC):
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
self.is_eval = True
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
@@ -188,7 +215,40 @@ class AgentSessionABC(ABC):
|
||||
|
||||
@classmethod
@abstractmethod
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
    """
    Rebuild an agent session from a previously saved session directory.

    Reads ``session_metadata.json`` from ``path``, writes the training
    config and lay down config stored under its ``"env"`` key to
    ``<path>/.temp/tc.yaml`` and ``<path>/.temp/ldc.yaml``, constructs
    ``cls`` from those two files and then points the new instance's
    ``session_path`` at ``path``.

    Although abstract, this method carries a working body; presumably
    concrete subclasses call it via ``super().load(path)`` before
    restoring their own agent state -- confirm against the subclasses.

    :param path: The session directory, as a ``str`` or a ``Path``.
    :return: The newly constructed session instance.
    :raises FileNotFoundError: If ``path`` does not exist.
    """
    # Path() accepts both str and Path, so no isinstance check is needed.
    path = Path(path)

    # Guard clause: fail early so the happy path below stays flat.
    if not path.exists():
        # Session path does not exist
        msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
        _LOGGER.error(msg)
        raise FileNotFoundError(msg)

    # Unpack the session_metadata.json file
    md_file = path / "session_metadata.json"
    with open(md_file, "r") as file:
        md_dict = json.load(file)

    # Create a temp directory and dump the training and lay down
    # configs into it, because the constructor is given config file
    # paths rather than in-memory dicts.
    temp_dir = path / ".temp"
    temp_dir.mkdir(exist_ok=True)

    temp_tc = temp_dir / "tc.yaml"
    with open(temp_tc, "w") as file:
        yaml.dump(md_dict["env"]["training_config"], file)

    temp_ldc = temp_dir / "ldc.yaml"
    with open(temp_ldc, "w") as file:
        yaml.dump(md_dict["env"]["lay_down_config"], file)

    # NOTE(review): constructing ``cls`` runs its normal initialisation
    # first; if that allocates a fresh session directory, an unused one
    # is left behind once session_path is overridden below -- confirm.
    agent = cls(temp_tc, temp_ldc)

    # Re-point the instance at the existing session on disk.
    agent.session_path = path

    return agent
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -44,6 +44,8 @@ class SB3Agent(AgentSessionABC):
|
||||
f"{self._training_config.agent_identifier}"
|
||||
)
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
self._env = Primaite(
|
||||
@@ -86,6 +88,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
@@ -108,6 +111,7 @@ class SB3Agent(AgentSessionABC):
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
|
||||
@@ -35,7 +35,7 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
action_type: NODE
|
||||
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 1000
|
||||
|
||||
@@ -260,6 +260,8 @@ class Primaite(Env):
|
||||
learning_session=False
|
||||
)
|
||||
self.episode_count = 0
|
||||
self.step_count = 0
|
||||
self.total_step_count = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -329,8 +331,6 @@ class Primaite(Env):
|
||||
# Load the action space into the transaction
|
||||
transaction.action_space = copy.deepcopy(action)
|
||||
|
||||
initial_nodes = copy.deepcopy(self.nodes)
|
||||
|
||||
# 1. Implement Blue Action
|
||||
self.interpret_action_and_apply(action)
|
||||
# Take snapshots of nodes and links
|
||||
@@ -383,7 +383,7 @@ class Primaite(Env):
|
||||
|
||||
# 5. Calculate reward signal (for RL)
|
||||
reward = calculate_reward_function(
|
||||
initial_nodes,
|
||||
self.nodes_post_pol,
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
|
||||
Reference in New Issue
Block a user