From baa14b6cd7ba85805970bbffbb047c7ae7a6bac5 Mon Sep 17 00:00:00 2001 From: "Czar.Echavez" Date: Tue, 11 Jul 2023 15:03:02 +0100 Subject: [PATCH] #1595: Moved hardcoded agent into its own file --- src/primaite/agents/agent.py | 101 ------------------------ src/primaite/agents/hardcoded.py | 108 ++++++++++++++++++++++++++ src/primaite/agents/hardcoded_acl.py | 2 +- src/primaite/agents/hardcoded_node.py | 2 +- src/primaite/agents/simple.py | 2 +- 5 files changed, 111 insertions(+), 104 deletions(-) create mode 100644 src/primaite/agents/hardcoded.py diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index a9bdfb1e..2bb24d62 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path @@ -308,103 +307,3 @@ class AgentSessionABC(ABC): fig = plot_av_reward_per_episode(path, title, subtitle) fig.write_image(image_path) _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") - - -class HardCodedAgentSessionABC(AgentSessionABC): - """ - An Agent Session ABC for evaluation deterministic agents. - - This class cannot be directly instantiated and must be inherited from with all implemented abstract methods - implemented. - """ - - def __init__(self, training_config_path, lay_down_config_path): - """ - Initialise a hardcoded agent session. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. 
- :type lay_down_config_path: Union[path, str] - """ - super().__init__(training_config_path, lay_down_config_path) - self._setup() - - def _setup(self): - self._env: Primaite = Primaite( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ) - super()._setup() - self._can_learn = False - self._can_evaluate = True - - def _save_checkpoint(self): - pass - - def _get_latest_checkpoint(self): - pass - - def learn( - self, - **kwargs, - ): - """ - Train the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - _LOGGER.warning("Deterministic agents cannot learn") - - @abstractmethod - def _calculate_action(self, obs): - pass - - def evaluate( - self, - **kwargs, - ): - """ - Evaluate the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - self._env.set_as_eval() # noqa - self.is_eval = True - - time_steps = self._training_config.num_steps - episodes = self._training_config.num_episodes - - obs = self._env.reset() - for episode in range(episodes): - # Reset env and collect initial observation - for step in range(time_steps): - # Calculate action - action = self._calculate_action(obs) - - # Perform the step - obs, reward, done, info = self._env.step(action) - - if done: - break - - # Introduce a delay between steps - time.sleep(self._training_config.time_delay / 1000) - obs = self._env.reset() - self._env.close() - - @classmethod - def load(cls): - """Load an agent from file.""" - _LOGGER.warning("Deterministic agents cannot be loaded") - - def save(self): - """Save the agent.""" - _LOGGER.warning("Deterministic agents cannot be saved") - - def export(self): - """Export the agent to transportable file format.""" - _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded.py b/src/primaite/agents/hardcoded.py new file mode 100644 index 00000000..26972726 
--- /dev/null
+++ b/src/primaite/agents/hardcoded.py
@@ -0,0 +1,108 @@
+import time
+from abc import abstractmethod
+
+from primaite import getLogger
+from primaite.agents.agent import AgentSessionABC
+from primaite.environment.primaite_env import Primaite
+
+_LOGGER = getLogger(__name__)
+
+
+class HardCodedAgentSessionABC(AgentSessionABC):
+    """
+    An Agent Session ABC for evaluating deterministic agents.
+
+    This class cannot be instantiated directly; subclass it and implement every abstract method, including
+    `_calculate_action`.
+    """
+
+    def __init__(self, training_config_path, lay_down_config_path):
+        """
+        Initialise a hardcoded agent session.
+
+        :param training_config_path: YAML file containing configurable items defined in
+            `primaite.config.training_config.TrainingConfig`
+        :type training_config_path: Union[path, str]
+        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
+        :type lay_down_config_path: Union[path, str]
+        """
+        super().__init__(training_config_path, lay_down_config_path)
+        self._setup()
+
+    def _setup(self):
+        self._env: Primaite = Primaite(
+            training_config_path=self._training_config_path,
+            lay_down_config_path=self._lay_down_config_path,
+            session_path=self.session_path,
+            timestamp_str=self.timestamp_str,
+        )
+        super()._setup()
+        self._can_learn = False
+        self._can_evaluate = True
+
+    def _save_checkpoint(self):
+        pass
+
+    def _get_latest_checkpoint(self):
+        pass
+
+    def learn(
+        self,
+        **kwargs,
+    ):
+        """
+        Log a warning and return; deterministic agents cannot learn.
+
+        :param kwargs: Agent-specific key-word args; unused by this implementation.
+        """
+        _LOGGER.warning("Deterministic agents cannot learn")
+
+    @abstractmethod
+    def _calculate_action(self, obs):
+        pass
+
+    def evaluate(
+        self,
+        **kwargs,
+    ):
+        """
+        Run the agent through the configured number of episodes, then close the environment.
+
+        :param kwargs: Agent-specific key-word args; unused by this implementation.
+        """
+        self._env.set_as_eval()  # noqa
+        self.is_eval = True
+
+        time_steps = self._training_config.num_steps
+        episodes = self._training_config.num_episodes
+
+        obs = self._env.reset()
+        for episode in range(episodes):
+            # obs was produced by the reset before this loop or at the end of the previous episode
+            for step in range(time_steps):
+                # Calculate action
+                action = self._calculate_action(obs)
+
+                # Perform the step
+                obs, reward, done, info = self._env.step(action)
+
+                if done:
+                    break
+
+                # Pause for time_delay milliseconds between steps (skipped once done)
+                time.sleep(self._training_config.time_delay / 1000)
+            obs = self._env.reset()
+        self._env.close()
+
+    @classmethod
+    def load(cls, path=None):
+        """Log a warning and return None; deterministic agents cannot be loaded (path is ignored)."""
+        _LOGGER.warning("Deterministic agents cannot be loaded")
+
+    def save(self):
+        """Log a warning and return None; deterministic agents cannot be saved."""
+        _LOGGER.warning("Deterministic agents cannot be saved")
+
+    def export(self):
+        """Log a warning and return None; deterministic agents cannot be exported."""
+        _LOGGER.warning("Deterministic agents cannot be exported")
diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py
index 166ff415..d736a378 100644
--- a/src/primaite/agents/hardcoded_acl.py
+++ b/src/primaite/agents/hardcoded_acl.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from primaite.acl.access_control_list import AccessControlList
 from primaite.acl.acl_rule import ACLRule
-from primaite.agents.agent import HardCodedAgentSessionABC
+from primaite.agents.hardcoded import HardCodedAgentSessionABC
 from primaite.agents.utils import (
     get_new_action,
     get_node_of_ip,
diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py
index c00cf421..757f31da 100644
--- a/src/primaite/agents/hardcoded_node.py
+++ b/src/primaite/agents/hardcoded_node.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from primaite.agents.agent import HardCodedAgentSessionABC
+from primaite.agents.hardcoded import HardCodedAgentSessionABC
 from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
 
 
diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py
index b429a2f5..230cd5e7 100644
--- a/src/primaite/agents/simple.py
+++ b/src/primaite/agents/simple.py
@@ -1,4 +1,4 @@
-from primaite.agents.agent import HardCodedAgentSessionABC
+from primaite.agents.hardcoded import HardCodedAgentSessionABC
 from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
 
 