#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -5,7 +5,7 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Dict, Final, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
@@ -51,16 +51,12 @@ class AgentSessionABC(ABC):
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path
|
||||
)
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(
|
||||
self._lay_down_config_path
|
||||
)
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.output_verbose_level = self._training_config.output_verbose_level
|
||||
|
||||
self._env: Primaite
|
||||
@@ -132,9 +128,7 @@ class AgentSessionABC(ABC):
|
||||
"learning": {"total_episodes": None, "total_time_steps": None},
|
||||
"evaluation": {"total_episodes": None, "total_time_steps": None},
|
||||
"env": {
|
||||
"training_config": self._training_config.to_dict(
|
||||
json_serializable=True
|
||||
),
|
||||
"training_config": self._training_config.to_dict(json_serializable=True),
|
||||
"lay_down_config": self._lay_down_config,
|
||||
},
|
||||
}
|
||||
@@ -161,19 +155,11 @@ class AgentSessionABC(ABC):
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"][
|
||||
"total_episodes"
|
||||
] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"][
|
||||
"total_time_steps"
|
||||
] = self._env.total_step_count # noqa
|
||||
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"][
|
||||
"total_episodes"
|
||||
] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"][
|
||||
"total_time_steps"
|
||||
] = self._env.total_step_count # noqa
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -184,12 +170,9 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment "
|
||||
f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(
|
||||
f"The output directory for this session is: {self.session_path}"
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
|
||||
self._write_session_metadata_file()
|
||||
self._can_learn = True
|
||||
self._can_evaluate = False
|
||||
@@ -201,17 +184,11 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_learn:
|
||||
@@ -225,17 +202,11 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
@@ -281,9 +252,7 @@ class AgentSessionABC(ABC):
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = (
|
||||
f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
)
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
@@ -354,17 +323,11 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
@@ -375,27 +338,19 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
|
||||
@@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
if (
|
||||
self._training_config.hard_coded_agent_view
|
||||
== HardCodedAgentView.BASIC
|
||||
):
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
@@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
):
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
@@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
@@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
dest_node_address = dest_node_id
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[
|
||||
protocol - 1
|
||||
] # -1 as dont have to account for ANY in list of services
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_allow_acl_rules(
|
||||
@@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = rule.get_source_ip()
|
||||
action_source_id = int(
|
||||
get_node_of_ip(action_source_ip, self._env.nodes)
|
||||
)
|
||||
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
|
||||
action_destination_ip = rule.get_dest_ip()
|
||||
action_destination_id = int(
|
||||
get_node_of_ip(
|
||||
action_destination_ip, self._env.nodes
|
||||
)
|
||||
)
|
||||
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
|
||||
action_protocol_name = rule.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name)
|
||||
+ 1
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = rule.get_port()
|
||||
action_port = (
|
||||
@@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if not found_action:
|
||||
# Which Green IERS are blocked
|
||||
blocked_green_iers = self.get_blocked_green_iers(
|
||||
self._env.green_iers, self._env.acl, self._env.nodes
|
||||
)
|
||||
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
|
||||
for ier_key, ier in blocked_green_iers.items():
|
||||
# Which ALLOW rules are allowing this IER (none)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(
|
||||
ier, self._env.acl, self._env.nodes
|
||||
)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
|
||||
|
||||
# If there are no blocking rules, it may be being blocked by default
|
||||
# If there is already an allow rule
|
||||
node_id_to_check = int(ier.get_source_node_id())
|
||||
service_name_to_check = ier.get_protocol()
|
||||
service_id_to_check = self._env.services_list.index(
|
||||
service_name_to_check
|
||||
)
|
||||
service_id_to_check = self._env.services_list.index(service_name_to_check)
|
||||
|
||||
# Service state of the the source node in the ier
|
||||
service_state = s[service_id_to_check][node_id_to_check - 1]
|
||||
@@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
number_of_nodes = len(
|
||||
[i for i in o if i != "NONE"]
|
||||
) # number of nodes (not links)
|
||||
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
|
||||
for service_num, service_states in enumerate(s):
|
||||
comprimised_states = [
|
||||
n for n, i in enumerate(service_states) if i == "COMPROMISED"
|
||||
]
|
||||
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
|
||||
if len(comprimised_states) == 0:
|
||||
# No states are COMPROMISED, try the next service
|
||||
continue
|
||||
|
||||
compromised_node = (
|
||||
np.random.choice(comprimised_states) + 1
|
||||
) # +1 as 0 would be any
|
||||
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = compromised_node
|
||||
# Randomly select a destination ID to block
|
||||
action_destination_ip = np.random.choice(
|
||||
list(range(1, number_of_nodes + 1)) + ["ANY"]
|
||||
)
|
||||
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
|
||||
action_destination_ip = (
|
||||
int(action_destination_ip)
|
||||
if action_destination_ip != "ANY"
|
||||
else action_destination_ip
|
||||
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
|
||||
)
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
# Randomly select a port to block
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_action_node_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
@@ -93,12 +89,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
if os_state == "OFF":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OPERATING"
|
||||
property_action = (
|
||||
"ON" # Why reset it when we can just turn it on
|
||||
)
|
||||
action_service_index = (
|
||||
0 # does nothing isn't relevant for operating state
|
||||
)
|
||||
property_action = "ON" # Why reset it when we can just turn it on
|
||||
action_service_index = 0 # does nothing isn't relevant for operating state
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
|
||||
@@ -3,9 +3,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
@@ -14,11 +13,7 @@ from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import (
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
)
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -49,10 +44,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = (
|
||||
f"Expected RLLIB agent_framework, "
|
||||
f"got {self._training_config.agent_framework}"
|
||||
)
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
@@ -60,10 +52,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = (
|
||||
"Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}"
|
||||
)
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: PPOConfig
|
||||
@@ -94,12 +83,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result[
|
||||
"episodes_total"
|
||||
]
|
||||
metadata_dict["total_time_steps"] = self._current_result[
|
||||
"timesteps_total"
|
||||
]
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -122,9 +107,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
@@ -132,72 +115,41 @@ class RLlibAgent(AgentSessionABC):
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(
|
||||
logger_creator=_custom_log_creator(self.session_path)
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (
|
||||
episode_count == self._training_config.num_episodes
|
||||
):
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
# Temporarily override train_batch_size and horizon
|
||||
if time_steps:
|
||||
self._agent_config.train_batch_size = time_steps
|
||||
self._agent_config.horizon = time_steps
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps..."
|
||||
)
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
if (
|
||||
self._training_config.deep_learning_framework
|
||||
!= DeepLearningFramework.TORCH
|
||||
):
|
||||
policy = self._agent.get_policy()
|
||||
tf.compat.v1.summary.FileWriter(
|
||||
self.session_path / "ray_results", policy.get_session().graph
|
||||
)
|
||||
super().learn()
|
||||
self._agent.stop()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -21,10 +21,7 @@ class SB3Agent(AgentSessionABC):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = (
|
||||
f"Expected SB3 agent_framework, "
|
||||
f"got {self._training_config.agent_framework}"
|
||||
)
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
@@ -32,10 +29,7 @@ class SB3Agent(AgentSessionABC):
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = (
|
||||
"Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier}"
|
||||
)
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -64,19 +58,15 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env,
|
||||
verbose=self.output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (
|
||||
episode_count == self._training_config.num_episodes
|
||||
):
|
||||
checkpoint_path = (
|
||||
self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
)
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
@@ -85,58 +75,37 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps..."
|
||||
)
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
|
||||
self.close()
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
deterministic: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if deterministic:
|
||||
@@ -144,19 +113,18 @@ class SB3Agent(AgentSessionABC):
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for "
|
||||
f"{episodes} episodes @ {time_steps} time steps..."
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(
|
||||
obs, deterministic=deterministic
|
||||
)
|
||||
action, _states = self._agent.predict(obs, deterministic=deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_action_acl_enum,
|
||||
transform_action_node_enum,
|
||||
)
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
|
||||
@@ -24,9 +24,7 @@ def transform_action_node_readable(action):
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (
|
||||
action_node_property == "OS" or action_node_property == "SERVICE"
|
||||
) and action[2] <= 1:
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
@@ -117,11 +115,7 @@ def is_valid_acl_action(action):
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
@@ -173,9 +167,7 @@ def transform_change_obs_readable(obs):
|
||||
|
||||
for service in range(3, obs.shape[1]):
|
||||
# Links bit/s don't have a service state
|
||||
service_states = [
|
||||
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
|
||||
]
|
||||
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
|
||||
new_obs.append(service_states)
|
||||
|
||||
return new_obs
|
||||
@@ -247,9 +239,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(
|
||||
obs1, obs2, num_nodes=10, num_links=10, num_services=1
|
||||
):
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Return string describing change between two observations.
|
||||
|
||||
@@ -291,16 +281,9 @@ def _describe_obs_change_helper(obs_change, is_link):
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
# Indexes where a change has occured, not including 0th index
|
||||
index_changed = [
|
||||
i for i in range(1, len(obs_change)) if obs_change[i] != -1
|
||||
]
|
||||
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
|
||||
# Node pol types, Indexes >= 3 are service nodes
|
||||
NodePOLTypes = [
|
||||
NodePOLType(i).name
|
||||
if i < 3
|
||||
else NodePOLType(3).name + " " + str(i - 3)
|
||||
for i in index_changed
|
||||
]
|
||||
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
|
||||
# Account for hardware states, software sattes and links
|
||||
states = [
|
||||
LinkStatus(obs_change[i]).name
|
||||
@@ -367,9 +350,7 @@ def transform_action_node_readable(action):
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (
|
||||
action_node_property == "OS" or action_node_property == "SERVICE"
|
||||
) and action[2] <= 1:
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
@@ -397,9 +378,7 @@ def node_action_description(action):
|
||||
if property_action == "NONE":
|
||||
return ""
|
||||
if node_property == "OPERATING" or node_property == "OS":
|
||||
description = (
|
||||
f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
)
|
||||
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
elif node_property == "SERVICE":
|
||||
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
|
||||
else:
|
||||
@@ -522,11 +501,7 @@ def is_valid_acl_action(action):
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
|
||||
Reference in New Issue
Block a user