From b6d93ad33f47d31e9e061ca9d10533b1cf6d8dcf Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Jun 2023 19:54:00 +0100 Subject: [PATCH] #917 - Began the process of reloading existing agents into the session --- src/primaite/agents/agent.py | 88 ++++++++++++++++--- src/primaite/agents/sb3.py | 4 + .../training/training_config_main.yaml | 2 +- src/primaite/environment/primaite_env.py | 6 +- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 05133b7e..f545a3cb 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json import time from abc import ABC, abstractmethod @@ -6,6 +7,8 @@ from pathlib import Path from typing import Optional, Final, Dict, Union from uuid import uuid4 +import yaml + import primaite from primaite import getLogger, SESSIONS_DIR from primaite.config import lay_down_config @@ -58,23 +61,34 @@ class AgentSessionABC(ABC): self._agent = None self._can_learn: bool = False self._can_evaluate: bool = False + self.is_eval = False self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() "The session timestamp" self.session_path = _get_session_path(self.session_timestamp) "The Session path" - self.learning_path = self.session_path / "learning" - "The learning outputs path" - self.evaluation_path = self.session_path / "evaluation" - "The evaluation outputs path" - self.checkpoints_path = self.learning_path / "checkpoints" self.checkpoints_path.mkdir(parents=True, exist_ok=True) - "The Session checkpoints path" - self.timestamp_str = self.session_timestamp.strftime( - "%Y-%m-%d_%H-%M-%S") - "The session timestamp as a string" + @property + def timestamp_str(self) -> str: + """The session timestamp as a string.""" + return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + + @property + def learning_path(self) -> Path: + """The learning outputs path.""" + return self.session_path / "learning" + + @property + def evaluation_path(self) -> Path: + """The evaluation outputs path.""" + return self.session_path / "evaluation" + + @property + def checkpoints_path(self) -> Path: + """The Session checkpoints path.""" + return self.learning_path / "checkpoints" @property def uuid(self): @@ -104,8 +118,14 @@ class AgentSessionABC(ABC): "uuid": self.uuid, "start_datetime": self.session_timestamp.isoformat(), "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, + "learning": { + "total_episodes": None, + "total_time_steps": None + }, + "evaluation": { + "total_episodes": None, + "total_time_steps": None + }, "env": { "training_config": self._training_config.to_dict( json_serializable=True @@ -134,8 +154,13 @@ class AgentSessionABC(ABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._env.episode_count - metadata_dict["total_time_steps"] = self._env.total_step_count + + if not self.is_eval: + metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa + else: + metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -172,6 +197,7 @@ class AgentSessionABC(ABC): _LOGGER.debug("Writing transactions") self._update_session_metadata_file() self._can_evaluate = True + self.is_eval = False @abstractmethod def evaluate( @@ -180,6 +206,7 @@ class AgentSessionABC(ABC): episodes: Optional[int] = None, **kwargs ): + self.is_eval = True _LOGGER.info("Finished evaluation") @abstractmethod @@ -188,7 +215,40 @@ class AgentSessionABC(ABC): @classmethod @abstractmethod - def load(cls): + def load(cls, path: Union[str, Path]) -> AgentSessionABC: + if not isinstance(path, Path): + path = Path(path) + + if path.exists(): + # Unpack the session_metadata.json file + md_file = path / "session_metadata.json" + with open(md_file, "r") as file: + md_dict = json.load(file) + + # Create a temp directory and dump the training and lay down + # configs into it + temp_dir = path / ".temp" + temp_dir.mkdir(exist_ok=True) + + temp_tc = temp_dir / "tc.yaml" + with open(temp_tc, "w") as file: + yaml.dump(md_dict["env"]["training_config"], file) + + temp_ldc = temp_dir / "ldc.yaml" + with open(temp_ldc, "w") as file: + yaml.dump(md_dict["env"]["lay_down_config"], file) + + agent = cls(temp_tc, temp_ldc) + + agent.session_path = path + + return agent + + else: + # Session path does not exist + msg = f"Failed to load PrimAITE Session, path does not exist: {path}" + _LOGGER.error(msg) + raise FileNotFoundError(msg) pass @abstractmethod diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index c183c544..328e6286 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -44,6 +44,8 @@ class SB3Agent(AgentSessionABC): f"{self._training_config.agent_identifier}" ) + self.is_eval = False + def _setup(self): super()._setup() self._env = Primaite( @@ -86,6 +88,7 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes + self.is_eval = False _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): @@ -108,6 +111,7 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes self._env.set_as_eval() + self.is_eval = True if deterministic: deterministic_str = "deterministic" else: diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 0e0212f4..3cccbcae 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -35,7 +35,7 @@ hard_coded_agent_view: FULL # "NODE" # "ACL" # "ANY" node and acl actions -action_type: ANY +action_type: NODE # Number of episodes to run per session num_episodes: 1000 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5b344a99..e43dc8a5 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -260,6 +260,8 @@ class Primaite(Env): learning_session=False ) self.episode_count = 0 + self.step_count = 0 + self.total_step_count = 0 def reset(self): """ @@ -329,8 +331,6 @@ class Primaite(Env): # Load the action space into the transaction transaction.action_space = copy.deepcopy(action) - initial_nodes = copy.deepcopy(self.nodes) - # 1. Implement Blue Action self.interpret_action_and_apply(action) # Take snapshots of nodes and links @@ -383,7 +383,7 @@ class Primaite(Env): # 5. Calculate reward signal (for RL) reward = calculate_reward_function( - initial_nodes, + self.nodes_post_pol, self.nodes_post_red, self.nodes_reference, self.green_iers,