#1595: Moved hardcoded agent into its own file

This commit is contained in:
Czar.Echavez
2023-07-11 15:03:02 +01:00
parent 0c63d197e5
commit baa14b6cd7
5 changed files with 111 additions and 104 deletions

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
@@ -308,103 +307,3 @@ class AgentSessionABC(ABC):
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -0,0 +1,108 @@
import time
from abc import abstractmethod
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path)
self._setup()
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
def _get_latest_checkpoint(self):
pass
def learn(
self,
**kwargs,
):
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls, path=None):
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -4,7 +4,7 @@ import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,

View File

@@ -1,6 +1,6 @@
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable

View File

@@ -1,4 +1,4 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.hardcoded import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum