#917 - Finished integrating all agents to either train (policy agents) or evaluate (hard-coded agents). Still some fixing up to do — tidying, loading, etc. — plus docs. But this is all now working.

This commit is contained in:
Chris McCarthy
2023-06-20 16:06:55 +01:00
parent 03ae4884e0
commit a2cc4233b5
16 changed files with 1125 additions and 220 deletions

View File

@@ -29,7 +29,7 @@ The environment config file consists of the following attributes:
* SB3 - Stable Baselines3
* RLLIB - Ray RLlib.
* **red_agent_identifier**
* **agent_identifier**
This identifies the agent to use for the session. Select from one of the following:

View File

@@ -1 +1 @@
2.0.0b1
2.0.0rc1

View File

@@ -1,4 +1,5 @@
import json
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
@@ -12,7 +13,6 @@ from primaite.config import training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
@@ -196,50 +196,77 @@ class AgentSessionABC(ABC):
pass
class DeterministicAgentSessionABC(AgentSessionABC):
@abstractmethod
def __init__(
self,
training_config_path,
lay_down_config_path
):
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self._env: Primaite
self._agent = None
class HardCodedAgentSessionABC(AgentSessionABC):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
self._setup()
@abstractmethod
def _setup(self):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=[],
session_path=self.session_path,
timestamp_str=self.timestamp_str
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
pass
@abstractmethod
def _get_latest_checkpoint(self):
pass
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
episodes: Optional[int] = None,
**kwargs
):
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
pass
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
episodes: Optional[int] = None,
**kwargs
):
pass
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
for episode in range(episodes):
# Reset env and collect initial observation
obs = self._env.reset()
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
self._env.close()
@classmethod
@abstractmethod
def load(cls):
pass
_LOGGER.warning("Deterministic agents cannot be loaded")
@abstractmethod
def save(self):
pass
_LOGGER.warning("Deterministic agents cannot be saved")
@abstractmethod
def export(self):
pass
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -0,0 +1,376 @@
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.enums import HardCodedAgentView
class HardCodedACLAgent(HardCodedAgentSessionABC):
    """A deterministic blue agent that selects ACL-based actions using hard-coded rules."""

    def _calculate_action(self, obs):
        """Dispatch to the basic- or full-view action calculation based on the training config."""
        if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
            # Basic view action using only the current observation
            return self._calculate_action_basic_view(obs)
        else:
            # full view action using observation space, action
            # history and reward feedback
            return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes):
    """Return the subset of green IERs that the ACL currently blocks.

    An IER counts as blocked either via an explicit DENY rule or by
    default (no ALLOW rule exists for it).
    """
    blocked = {}
    for ier_id, ier in green_iers.items():
        src_ip = nodes[ier.get_source_node_id()].ip_address
        dst_ip = nodes[ier.get_dest_node_id()].ip_address
        if acl.is_blocked(src_ip, dst_ip, ier.get_protocol(), ier.get_port()):
            blocked[ier_id] = ier
    return blocked
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
    """Get matching ACL rules for an IER."""
    src_ip = nodes[ier.get_source_node_id()].ip_address
    dst_ip = nodes[ier.get_dest_node_id()].ip_address
    # Protocol is a name such as 'TCP'
    return acl.get_relevant_rules(src_ip, dst_ip, ier.get_protocol(), ier.get_port())
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
    """Get the DENY ACL rules matching an IER.

    Warning: an empty dict does not imply the IER is allowed - it can still be
    blocked by default (no ALLOW rule, therefore blocked).
    """
    matching = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    return {key: rule for key, rule in matching.items() if rule.get_permission() == "DENY"}
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
    """Get all ALLOW ACL rules matching an IER."""
    matching = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    return {key: rule for key, rule in matching.items() if rule.get_permission() == "ALLOW"}
def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol,
                           port, acl, nodes, services_list):
    """Return ACL rules matching the given endpoints/protocol/port.

    Any of the arguments may be the literal string "ANY", which is passed
    through unchanged. Node IDs are resolved to IP addresses; a numeric
    protocol is resolved to its service name via services_list.
    """
    if source_node_id == "ANY":
        src_ip = source_node_id
    else:
        src_ip = nodes[str(source_node_id)].ip_address
    if dest_node_id == "ANY":
        dst_ip = dest_node_id
    else:
        dst_ip = nodes[str(dest_node_id)].ip_address
    if protocol != "ANY":
        # -1 because services_list has no "ANY" entry to account for
        protocol = services_list[protocol - 1]
    return acl.get_relevant_rules(src_ip, dst_ip, protocol, port)
def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol,
                        port, acl, nodes, services_list):
    """Return the matching ACL rules whose permission is ALLOW."""
    matching = self.get_matching_acl_rules(
        source_node_id, dest_node_id, protocol, port, acl, nodes, services_list
    )
    return {key: rule for key, rule in matching.items() if rule.get_permission() == "ALLOW"}
def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port,
                       acl, nodes, services_list):
    """Return the matching ACL rules whose permission is DENY."""
    matching = self.get_matching_acl_rules(
        source_node_id, dest_node_id, protocol, port, acl, nodes, services_list
    )
    return {key: rule for key, rule in matching.items() if rule.get_permission() == "DENY"}
def _calculate_action_full_view(self, obs):
    """
    Given an observation and the environment calculate a good acl-based action for the blue agent to take.

    Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
    - Which ACL rules already exist - otherwise:
        - The agent would permanently get stuck in a loop of performing the same action over and over.
          (best action is to block something, but it's already blocked and the agent doesn't know this)
        - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
          if it doesn't know what rules exist)
    - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
      in the default config one of the green IERs is blocked by default, but it has no way of knowing this
      based on the observation space. Additionally, potentially in the future, once a node state
      has been fixed (no longer compromised), it needs a way to know it should re-allow traffic.
      An RL agent can learn what the green IERs are on its own - but the rule-based agent cannot easily do this.

    There doesn't seem to be much that can be done if an Operating or OS State is compromised.

    If a service node becomes compromised there's a decision to make - do we block that service?
        Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
        Cons: Will block a green IER, decreasing the reward
    We decide to block the service.

    Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
    nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
    an overwhelmed state, so we don't do this.
    """
    #obs = convert_to_old_obs(obs)
    r_obs = transform_change_obs_readable(obs)
    _, _, _, *s = r_obs
    if len(r_obs) == 4:  # only 1 service
        s = [*s]
    # 1. Check if a node is compromised. If so we want to block its outwards services.
    #    a. If it is compromised check if there's an allow rule we should delete.
    #       cons: might delete a multi-rule from any source node (ANY -> x)
    #    b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate
    #    c. OPTIONAL (no allow rule = blocked): Add a DENY rule
    found_action = False
    for service_num, service_states in enumerate(s):
        for x, service_state in enumerate(service_states):
            if service_state == "COMPROMISED":
                action_source_id = x + 1  # +1 as 0 is any
                action_destination_id = "ANY"
                action_protocol = service_num + 1  # +1 as 0 is any
                action_port = "ANY"
                allow_rules = self.get_allow_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                deny_rules = self.get_deny_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                if len(allow_rules) > 0:
                    # Check if there's an allow rule we should delete
                    rule = list(allow_rules.values())[0]
                    action_decision = "DELETE"
                    action_permission = "ALLOW"
                    action_source_ip = rule.get_source_ip()
                    action_source_id = int(
                        get_node_of_ip(action_source_ip, self._env.nodes))
                    action_destination_ip = rule.get_dest_ip()
                    action_destination_id = int(
                        get_node_of_ip(action_destination_ip,
                                       self._env.nodes))
                    action_protocol_name = rule.get_protocol()
                    action_protocol = (
                        self._env.services_list.index(
                            action_protocol_name) + 1
                    )  # convert name e.g. 'TCP' to index
                    action_port_name = rule.get_port()
                    action_port = self._env.ports_list.index(
                        action_port_name) + 1  # convert port name e.g. '80' to index
                    found_action = True
                    break
                elif len(deny_rules) > 0:
                    # TODO OPTIONAL
                    # If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
                    # to create another
                    # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
                    continue
                else:
                    # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
                    # NOTE(review): found_action is not set here, so this CREATE/DENY decision is
                    # discarded unless a later iteration finds an action - confirm this is intended.
                    action_decision = "CREATE"
                    action_permission = "DENY"
                    break
        if found_action:
            break
    # 2. If NO Node is Compromised, or the node has already been blocked, check the green IERs and
    #    add an Allow rule if the green IER is being blocked.
    #    a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
    #       If there's a DENY rule delete it if:
    #       - There isn't already a deny rule
    #       - It doesn't allow a compromised node to become operational.
    #    b. Add an ALLOW rule if:
    #       - There isn't already an allow rule
    #       - It doesn't allow a compromised node to become operational
    if not found_action:
        # Which Green IERs are blocked
        blocked_green_iers = self.get_blocked_green_iers(
            self._env.green_iers, self._env.acl,
            self._env.nodes)
        for ier_key, ier in blocked_green_iers.items():
            # Which ALLOW rules are allowing this IER (none)
            allowing_rules = self.get_allow_acl_rules_for_ier(ier,
                                                              self._env.acl,
                                                              self._env.nodes)
            # If there are no blocking rules, it may be being blocked by default
            # If there is already an allow rule
            node_id_to_check = int(ier.get_source_node_id())
            service_name_to_check = ier.get_protocol()
            service_id_to_check = self._env.services_list.index(
                service_name_to_check)
            # Service state of the source node in the IER
            service_state = s[service_id_to_check][node_id_to_check - 1]
            if len(allowing_rules) == 0 and service_state != "COMPROMISED":
                action_decision = "CREATE"
                action_permission = "ALLOW"
                action_source_id = int(ier.get_source_node_id())
                action_destination_id = int(ier.get_dest_node_id())
                action_protocol_name = ier.get_protocol()
                action_protocol = self._env.services_list.index(
                    action_protocol_name) + 1  # convert name e.g. 'TCP' to index
                action_port_name = ier.get_port()
                action_port = self._env.ports_list.index(
                    action_port_name) + 1  # convert port name e.g. '80' to index
                found_action = True
                break
    if found_action:
        action = [
            action_decision,
            action_permission,
            action_source_id,
            action_destination_id,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    else:
        # If no good/useful action has been found, just perform a nothing action
        action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    return action
def _calculate_action_basic_view(self, obs):
    """Given an observation calculate a good acl-based action for the blue agent to take.

    Uses ONLY information from the current observation with NO knowledge of previous actions taken and
    NO reward feedback.

    We rely on randomness to select the precise action, as we want to block all traffic originating from
    a compromised node, without being able to tell:
        1. Which ACL rules already exist
        2. Which actions the agent has already tried.
    There is a high probability that the correct rule will not be deleted before the state becomes overwhelmed.

    Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted.
    """
    action_dict = self._env.action_dict
    r_obs = transform_change_obs_readable(obs)
    _, o, _, *s = r_obs
    if len(r_obs) == 4:  # only 1 service
        s = [*s]
    number_of_nodes = len(
        [i for i in o if i != "NONE"])  # number of nodes (not links)
    for service_num, service_states in enumerate(s):
        comprimised_states = [n for n, i in enumerate(service_states) if
                              i == "COMPROMISED"]
        if len(comprimised_states) == 0:
            # No states are COMPROMISED, try the next service
            continue
        compromised_node = np.random.choice(
            comprimised_states) + 1  # +1 as 0 would be any
        action_decision = "DELETE"
        action_permission = "ALLOW"
        action_source_ip = compromised_node
        # Randomly select a destination ID to block
        action_destination_ip = np.random.choice(
            list(range(1, number_of_nodes + 1)) + ["ANY"])
        action_destination_ip = int(
            action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
        action_protocol = service_num + 1  # +1 as 0 is any
        # Randomly select a port to block
        # Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port
        action_port = np.random.choice(list(range(1, len(s) + 1)))
        action = [
            action_decision,
            action_permission,
            action_source_ip,
            action_destination_ip,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, action_dict)
        # We can only perform 1 action on each step
        return action
    # If no good/useful action has been found, just perform a nothing action
    nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
    nothing_action = transform_action_acl_enum(nothing_action)
    nothing_action = get_new_action(nothing_action, action_dict)
    return nothing_action

View File

@@ -0,0 +1,97 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_change_obs_readable,
)
from primaite.agents.utils import (
transform_action_node_enum,
)
class HardCodedNodeAgent(HardCodedAgentSessionABC):
    """A deterministic blue agent that chooses node-based remediation actions."""

    def _calculate_action(self, obs):
        """Given an observation calculate a good node-based action for the blue agent to take."""
        action_dict = self._env.action_dict
        r_obs = transform_change_obs_readable(obs)
        _, o, os, *s = r_obs
        if len(r_obs) == 4:  # only 1 service
            s = [*s]
        # Check in order of most important states (order doesn't currently matter, but it probably should)
        # First see if any OS states are compromised
        for x, os_state in enumerate(os):
            if os_state == "COMPROMISED":
                action_node_id = x + 1
                action_node_property = "OS"
                property_action = "PATCHING"
                action_service_index = 0  # does nothing, isn't relevant for OS
                action = [action_node_id, action_node_property,
                          property_action, action_service_index]
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action
        # Next, see if any Services are compromised
        # We fix the compromised state before overwhelmed state.
        # If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "COMPROMISED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num
                    action = [action_node_id, action_node_property,
                              property_action, action_service_index]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action
        # Next, see if any services are overwhelmed
        # perhaps this should be fixed automatically when the compromised PCs issues are also resolved
        # Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "OVERWHELMED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num
                    action = [action_node_id, action_node_property,
                              property_action, action_service_index]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action
        # Finally, turn on any off nodes
        for x, operating_state in enumerate(o):
            # BUG FIX: was testing the stale `os_state` variable left over from the
            # OS loop above, so OFF nodes were (almost) never turned back on.
            if operating_state == "OFF":
                action_node_id = x + 1
                action_node_property = "OPERATING"
                property_action = "ON"  # Why reset it when we can just turn it on
                action_service_index = 0  # does nothing, isn't relevant for operating state
                action = [action_node_id, action_node_property,
                          property_action, action_service_index]
                # BUG FIX: transform_action_node_enum takes a single argument; the
                # spurious action_dict argument raised a TypeError here.
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action
        # If no good actions, just go with an action that won't do any harm
        action_node_id = 1
        action_node_property = "NONE"
        property_action = "ON"
        action_service_index = 0
        action = [action_node_id, action_node_property, property_action,
                  action_service_index]
        action = transform_action_node_enum(action)
        action = get_new_action(action, action_dict)
        return action

View File

@@ -1,21 +1,19 @@
import json
from datetime import datetime
from pathlib import Path
from pathlib import Path
from typing import Optional
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, RedAgentIdentifier
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _env_creator(env_config):
@@ -51,13 +49,13 @@ class RLlibAgent(AgentSessionABC):
f"got {self._training_config.agent_framework}")
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = ("Expected PPO or A2C red_agent_identifier, "
f"got {self._training_config.red_agent_identifier.value}")
msg = ("Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}")
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: PPOConfig
@@ -67,8 +65,8 @@ class RLlibAgent(AgentSessionABC):
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"red_agent_identifier="
f"{self._training_config.red_agent_identifier}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
@@ -117,7 +115,7 @@ class RLlibAgent(AgentSessionABC):
train_batch_size=self._training_config.num_steps
)
self._agent_config.framework(
framework=self._training_config.deep_learning_framework
framework="torch"
)
self._agent_config.rollouts(

View File

@@ -2,12 +2,12 @@ from typing import Optional
import numpy as np
from stable_baselines3 import PPO, A2C
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import RedAgentIdentifier, AgentFramework
from primaite.common.enums import AgentIdentifier, AgentFramework
from primaite.environment.primaite_env import Primaite
from stable_baselines3.ppo import MlpPolicy as PPOMlp
_LOGGER = getLogger(__name__)
@@ -24,13 +24,13 @@ class SB3Agent(AgentSessionABC):
f"got {self._training_config.agent_framework}")
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = ("Expected PPO or A2C red_agent_identifier, "
f"got {self._training_config.red_agent_identifier.value}")
msg = ("Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}")
_LOGGER.error(msg)
raise ValueError(msg)
@@ -40,8 +40,8 @@ class SB3Agent(AgentSessionABC):
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"red_agent_identifier="
f"{self._training_config.red_agent_identifier}"
f"agent_identifier="
f"{self._training_config.agent_identifier}"
)
def _setup(self):
@@ -56,7 +56,7 @@ class SB3Agent(AgentSessionABC):
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self._training_config.output_verbose_level,
verbose=self.output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path
)
@@ -118,6 +118,7 @@ class SB3Agent(AgentSessionABC):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
@classmethod
def load(self):
raise NotImplementedError

View File

@@ -0,0 +1,60 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_action_acl_enum,
transform_action_node_enum,
)
class RandomAgent(HardCodedAgentSessionABC):
    """
    A Random Agent.

    Get a completely random action from the action space.
    """

    def _calculate_action(self, obs):
        # The observation is ignored; sample uniformly from the env's action space.
        return self._env.action_space.sample()
class DummyAgent(HardCodedAgentSessionABC):
    """
    A Dummy Agent.

    All action spaces setup so dummy action is always 0 regardless of action
    type used.
    """

    def _calculate_action(self, obs):
        # Action 0 is the "do nothing" entry in every supported action space.
        return 0
class DoNothingACLAgent(HardCodedAgentSessionABC):
    """
    A do nothing ACL agent.

    A valid ACL action that has no effect; does nothing.
    """

    def _calculate_action(self, obs):
        # Build the no-op ACL action and map it into the env's action space.
        noop = transform_action_acl_enum(["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"])
        return get_new_action(noop, self._env.action_dict)
class DoNothingNodeAgent(HardCodedAgentSessionABC):
    """
    A do nothing Node agent.

    A valid Node action that has no effect; does nothing.
    """

    def _calculate_action(self, obs):
        # Build the no-op node action and map it into the env's action space;
        # the resulting action should currently always be 0.
        noop = transform_action_node_enum([1, "NONE", "ON", 0])
        return get_new_action(noop, self._env.action_dict)

View File

@@ -1,4 +1,13 @@
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
import numpy as np
from primaite.common.enums import (
HardwareState,
LinkStatus,
NodeHardwareAction,
NodeSoftwareAction,
SoftwareState,
)
from primaite.common.enums import NodePOLType
def transform_action_node_readable(action):
@@ -125,3 +134,393 @@ def is_valid_acl_action_extra(action):
return False
return True
def transform_change_obs_readable(obs):
    """Transform an observation array into one readable list per observation property.

    example:
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
    """
    readable = [
        list(obs[:, 0]),
        [HardwareState(v).name for v in obs[:, 1]],
        [SoftwareState(v).name for v in obs[:, 2]],
    ]
    for col in range(3, obs.shape[1]):
        # Link bit/s values don't have a service state; values > 4 are kept raw.
        readable.append([SoftwareState(v).name if v <= 4 else v for v in obs[:, col]])
    return readable
def transform_obs_readable(obs):
    """Transform an observation into one readable row per node/link.

    example:
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
    """
    per_property = transform_change_obs_readable(obs)
    # Transpose the property-major lists into row-major lists of lists
    return [list(row) for row in zip(*per_property)]
def convert_to_new_obs(obs, num_nodes=10):
    """Convert the original gym Box observation space to the new MultiDiscrete observation space.

    Drops the ID column and the link rows, then flattens the node rows.
    """
    node_rows = obs[:num_nodes]
    return node_rows[:, 1:].flatten()
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
    """
    Convert to the old observation format; link rows are filled with 0's as no
    per-service information is included in the new observation space.

    example:
        obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
        new_obs = array([[ 1, 1, 1, 1],
        [ 2, 1, 1, 1],
        [ 3, 1, 1, 1],
        ...
        [20, 0, 0, 0]])
    """
    # Recover the node portion as a 2-D array (one row per node)
    node_part = obs[:-num_links].reshape(num_nodes, num_services + 2)
    total_rows = num_nodes + num_links
    old = np.zeros([total_rows, node_part.shape[1] + 1], dtype=np.int64)
    old[:, 0] = range(1, total_rows + 1)  # restore the ID column
    old[:num_nodes, 1:] = node_part
    # Link traffic goes into the last protocol/service slot, though the values
    # are not specific to that service
    old[num_nodes:, -1] = obs[-num_links:]
    return old
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
    """Return a string describing the change between two observations.

    example:
        obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
        obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
        output = 'ID 1: SERVICE 2 set to GOOD'
    """
    obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
    obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
    list_of_changes = []
    for n, row in enumerate(obs1 - obs2):
        # row.any() is True when any column of this row differs between the
        # two observations (was `row.any() != 0`, a bool-to-int comparison).
        if row.any():
            # Keep only the changed values (others become -1); take the new values
            relevant_changes = np.where(row != 0, obs2[n], -1)
            relevant_changes[0] = obs2[n, 0]  # ID is always relevant
            is_link = relevant_changes[0] > num_nodes
            desc = _describe_obs_change_helper(relevant_changes, is_link)
            list_of_changes.append(desc)
    change_string = "\n ".join(list_of_changes)
    if len(list_of_changes) > 0:
        change_string = "\n " + change_string
    return change_string
def _describe_obs_change_helper(obs_change, is_link):
    """Helper function to describe what has changed in a single observation row.

    example:
        [ 1 -1 -1 -1  1] -> "ID 1: SERVICE 1 changed to GOOD."
    Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
    """
    # Indexes where a change has occurred, not including the 0th (ID) index
    index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
    # Node POL types; indexes >= 3 are service entries.
    # BUG FIX: the loop below previously used `NodePOLType` as its loop variable,
    # shadowing the imported enum and making this comprehension raise
    # UnboundLocalError whenever index_changed was non-empty.
    pol_type_names = [
        NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed
    ]
    # Account for hardware states, software states and links
    states = [
        LinkStatus(obs_change[i]).name
        if is_link
        else HardwareState(obs_change[i]).name
        if i == 1
        else SoftwareState(obs_change[i]).name
        for i in index_changed
    ]
    if not is_link:
        desc = f"ID {obs_change[0]}:"
        for pol_type_name, state in zip(pol_type_names, states):
            desc = desc + " " + pol_type_name + " changed to " + state + "."
    else:
        desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
    return desc
def transform_action_node_enum(action):
    """
    Convert a node action from readable string format, to enumerated format

    example:
        [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
    """
    node_property = action[1]
    if node_property == "OPERATING":
        property_action_value = NodeHardwareAction[action[2]].value
    elif node_property in ("OS", "SERVICE"):
        property_action_value = NodeSoftwareAction[action[2]].value
    else:
        # Any other property has no meaningful action; encode as 0
        property_action_value = 0
    return [action[0], NodePOLType[node_property].value, property_action_value, action[3]]
def transform_action_node_readable(action):
    """
    Convert a node action from enumerated format to readable format

    example:
        [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
    """
    property_name = NodePOLType(action[1]).name
    if property_name == "OPERATING":
        action_name = NodeHardwareAction(action[2]).name
    elif property_name in ("OS", "SERVICE") and action[2] <= 1:
        action_name = NodeSoftwareAction(action[2]).name
    else:
        # Out-of-range software actions (or other properties) read as NONE
        action_name = "NONE"
    return [action[0], property_name, action_name, action[3]]
def node_action_description(action):
    """
    Generate a string describing a node-based action.

    Returns an empty string for actions with no effect.
    """
    if isinstance(action[1], (int, np.int64)):
        # Enumerated form: transform to the readable format first
        action = transform_action_node_readable(action)
    node_id, node_property, property_action, service_id = action[0], action[1], action[2], action[3]
    if property_action == "NONE":
        return ""
    if node_property in ("OPERATING", "OS"):
        return f"NODE {node_id}, {node_property}, SET TO {property_action}"
    if node_property == "SERVICE":
        return f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
    return ""
def transform_action_acl_readable(action):
    """
    Transform an ACL action to a more readable format

    example:
        [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
    """
    decision = ("NONE", "CREATE", "DELETE")[action[0]]
    permission = ("DENY", "ALLOW")[action[1]]
    # For IPs, Ports and Protocols, 0 means any, otherwise it's just an index
    tail = ["ANY" if value == 0 else value for value in list(action[2:6])]
    return [decision, permission] + tail
def transform_action_acl_enum(action):
    """
    Convert a acl action from readable string format, to enumerated format
    """
    decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    permissions = {"DENY": 0, "ALLOW": 1}
    # For IPs, Ports and Protocols, ANY has value 0, otherwise it's just an index
    tail = [0 if value == "ANY" else value for value in list(action[2:6])]
    return np.array([decisions[action[0]], permissions[action[1]]] + tail)
def acl_action_description(action):
    """Return a human-readable sentence describing an ACL-based action."""
    # Enumerated actions lead with an int decision; convert those first.
    if isinstance(action[0], (int, np.int64)):
        action = transform_action_acl_readable(action)
    if action[0] == "NONE":
        return "NO ACL RULE APPLIED"
    return (
        f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
        f" for protocol/service index {action[4]} on port index {action[5]}"
    )
def get_node_of_ip(ip, node_dict):
    """Return the ID of the node whose IP address matches ``ip``.

    :param ip: The IP address to look up.
    :param node_dict: Mapping of node ID to node object, where each node
        exposes an ``ip_address`` attribute (can be obtained from env.nodes).
    :return: The matching node ID, or ``None`` when no node has that IP.
    """
    matches = (key for key, node in node_dict.items() if node.ip_address == ip)
    return next(matches, None)
def is_valid_node_action(action):
    """Return True if a node action would actually have an effect.

    Judged purely from the action's own contents. Does NOT consider:
    - whether the target node can accept the operation (e.g. a node with
      no service cannot be patched);
    - whether the node is already in the requested state (turning an
      ON node ON).
    """
    _, node_property, node_action = transform_action_node_readable(action)[:3]
    # A no-op property or a no-op operation has no effect.
    if node_property == "NONE" or node_action == "NONE":
        return False
    # The operating state accepts anything except PATCHING.
    if node_property == "OPERATING":
        return node_action != "PATCHING"
    # Software states only accept NONE or PATCHING (NONE rejected above).
    return node_action == "PATCHING"
def is_valid_acl_action(action):
    """Return True if an ACL action would actually have an effect.

    Judged purely from the action's own contents. Does NOT consider:
    - creating a rule identical to an existing one;
    - creating a rule already subsumed by another rule (via "ANY").
    """
    decision, permission, source_id, destination_id = transform_action_acl_readable(action)[:4]
    if decision == "NONE":
        return False
    # A rule from a node to itself is pointless (unless either end is ANY).
    if source_id == destination_id and "ANY" not in (source_id, destination_id):
        return False
    # DENY rules are unnecessary: with no allow rule, traffic is blocked
    # (deny by default), and an ALLOW rule overrides an existing DENY, so
    # we only ever create and delete allow rules.
    return permission != "DENY"
def is_valid_acl_action_extra(action):
    """Stricter variant of is_valid_acl_action.

    Additionally rejects rules whose protocol or port is "ANY". (In the
    future we might want the opposite: allow ANY only for ports/services.)
    """
    if not is_valid_acl_action(action):
        return False
    protocol, port = transform_action_acl_readable(action)[4:6]
    return protocol != "ANY" and port != "ANY"
def get_new_action(old_action, action_dict):
    """Map a full action array (e.g. [1, 1, 1, 0]) to its key in ``action_dict``.

    Works for both node and ACL action formats. The dict only contains
    valid actions, so any action that is not listed is invalid and the
    sentinel key 0 is returned.
    """
    target = list(old_action)
    for action_key, action_value in action_dict.items():
        if list(action_value) == target:
            return action_key
    return 0
def get_action_description(action, action_dict):
    """Return a plain-English description of the action behind key ``action``.

    Node actions have 4 elements and ACL actions have 6; any other length
    should never occur and yields "Unrecognised action".
    """
    action_array = action_dict[action]
    describers = {4: node_action_description, 6: acl_action_description}
    describe = describers.get(len(action_array))
    if describe is None:
        return "Unrecognised action"
    return describe(action_array)

View File

@@ -32,6 +32,7 @@ class Priority(Enum):
class HardwareState(Enum):
"""Node hardware state enumeration."""
NONE = 0
ON = 1
OFF = 2
RESETTING = 3
@@ -42,6 +43,7 @@ class HardwareState(Enum):
class SoftwareState(Enum):
"""Software or Service state enumeration."""
NONE = 0
GOOD = 1
PATCHING = 2
COMPROMISED = 3
@@ -94,7 +96,8 @@ class VerboseLevel(IntEnum):
class AgentFramework(Enum):
NONE = 0
"""The agent algorithm framework/package."""
CUSTOM = 0
"Custom Agent"
SB3 = 1
"Stable Baselines3"
@@ -103,7 +106,7 @@ class AgentFramework(Enum):
class DeepLearningFramework(Enum):
"""The deep learning framework enumeration."""
"""The deep learning framework."""
TF = "tf"
"Tensorflow"
TF2 = "tf2"
@@ -112,15 +115,28 @@ class DeepLearningFramework(Enum):
"PyTorch"
class RedAgentIdentifier(Enum):
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"Custom Agent"
RANDOM = 4
"Custom Agent"
"The Hardcoded agents"
DO_NOTHING = 4
"The DoNothing agents"
RANDOM = 5
"The RandomAgent"
DUMMY = 6
"The DummyAgent"
class HardCodedAgentView(Enum):
"""The view the deterministic hard-coded agent has of the environment."""
BASIC = 1
"The current observation space only"
FULL = 2
"Full environment view with actions taken and reward feedback"
class ActionType(Enum):

View File

@@ -1,32 +1,41 @@
# Main Config File
# Training Config File
# Sets which agent algorithm framework will be used:
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "NONE" (Custom Agent)
# "CUSTOM" (Custom Agent)
agent_framework: RLLIB
# Sets which deep learning framework will be used. Default is TF (Tensorflow).
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TORCH
# Sets which Red Agent algo/class will be used:
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic)
# "PPO" (Proximal Policy Optimization)
# "HARDCODED" (Custom Agent)
# "RANDOM" (Random Action)
red_agent_identifier: PPO
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" (both node and ACL actions)
action_type: NODE
action_type: ACL
# Number of episodes to run per session
num_episodes: 10

View File

@@ -8,8 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import DeepLearningFramework
from primaite.common.enums import ActionType, RedAgentIdentifier, \
from primaite.common.enums import DeepLearningFramework, HardCodedAgentView
from primaite.common.enums import ActionType, AgentIdentifier, \
AgentFramework, SessionType, OutputVerboseLevel
_LOGGER = getLogger(__name__)
@@ -42,8 +42,11 @@ class TrainingConfig:
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
"The DeepLearningFramework"
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
"The RedAgentIdentifier"
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
"The AgentIdentifier"
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
"The view the deterministic hard-coded agent has of the environment"
action_type: ActionType = ActionType.ANY
"The ActionType to use"
@@ -176,10 +179,11 @@ class TrainingConfig:
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
"red_agent_identifier": RedAgentIdentifier,
"agent_identifier": AgentIdentifier,
"action_type": ActionType,
"session_type": SessionType,
"output_verbose_level": OutputVerboseLevel
"output_verbose_level": OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView
}
for field, enum_class in field_enum_map.items():
@@ -197,12 +201,13 @@ class TrainingConfig:
"""
data = self.__dict__
if json_serializable:
data["agent_framework"] = self.agent_framework.value
data["deep_learning_framework"] = self.deep_learning_framework.value
data["red_agent_identifier"] = self.red_agent_identifier.value
data["action_type"] = self.action_type.value
data["output_verbose_level"] = self.output_verbose_level.value
data["session_type"] = self.session_type.value
data["agent_framework"] = self.agent_framework.name
data["deep_learning_framework"] = self.deep_learning_framework.name
data["agent_identifier"] = self.agent_identifier.name
data["action_type"] = self.action_type.name
data["output_verbose_level"] = self.output_verbose_level.name
data["session_type"] = self.session_type.name
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
return data
@@ -255,7 +260,7 @@ def load(
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256,
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
@@ -266,8 +271,8 @@ def convert_legacy_training_config_dict(
:param legacy_config_dict: A legacy training config dict.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param red_agent_identifier: The red agent identifier to use as legacy
training configs don't have red_agent_identifier values.
:param agent_identifier: The red agent identifier to use as legacy
training configs don't have agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs
@@ -278,7 +283,7 @@ def convert_legacy_training_config_dict(
"""
config_dict = {
"agent_framework": agent_framework.name,
"red_agent_identifier": red_agent_identifier.name,
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps,
"output_verbose_level": output_verbose_level

View File

@@ -97,7 +97,7 @@ class Primaite(Env):
self.transaction_list = transaction_list
# The agent in use
self.agent_identifier = self.training_config.red_agent_identifier
self.agent_identifier = self.training_config.agent_identifier
# Create a dictionary to hold all the nodes
self.nodes: Dict[str, NodeUnion] = {}

View File

@@ -1,137 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""
The main PrimAITE session runner module.
TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_path and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class.
"""
"""The main PrimAITE session runner module."""
import argparse
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Final, Union
from uuid import uuid4
from typing import Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = getLogger(__name__)
def run_generic(env: Primaite, config_values: TrainingConfig):
"""
Run against a generic agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
"""
for episode in range(0, config_values.num_episodes):
env.reset()
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = env.action_space.sample()
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Introduce a delay between steps
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
env.close()
def run_stable_baselines3_ppo(
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
"""
Run against a stable_baselines3 PPO agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = PPO.load(
config_values.agent_load_file,
env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(config_values.num_episodes):
agent.learn(total_timesteps=config_values.num_steps)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close()
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
"""
Persist an agent.
Only works for stable baselines3 agents at present.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if not isinstance(agent, OnPolicyAlgorithm):
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
_LOGGER.error(msg)
else:
filepath = session_path / f"agent_saved_{timestamp_str}"
agent.save(filepath)
_LOGGER.debug(f"Trained agent saved as: {filepath}")
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
"""Run the PrimAITE Session.

View File

@@ -8,9 +8,13 @@ from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR
from primaite.agents.agent import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \
RandomAgent, DummyAgent
from primaite.common.enums import AgentFramework, AgentIdentifier, \
ActionType, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
@@ -68,31 +72,66 @@ class PrimaiteSession:
self.learn()
def setup(self):
if self._training_config.agent_framework == AgentFramework.NONE:
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
# Stochastic Random Agent
raise NotImplementedError
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
if self._training_config.agent_framework == AgentFramework.CUSTOM:
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
raise NotImplementedError
self._agent_session = HardCodedNodeAgent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
raise NotImplementedError
self._agent_session = HardCodedACLAgent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid RedAgentIdentifier ActionType combo
pass
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
self._agent_session = RandomAgent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
self._agent_session = DummyAgent(
self._training_config_path,
self._lay_down_config_path
)
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
# Invalid AgentFramework AgentIdentifier combo
raise ValueError
elif self._training_config.agent_framework == AgentFramework.SB3:
# Stable Baselines3 Agent
@@ -110,7 +149,7 @@ class PrimaiteSession:
else:
# Invalid AgentFramework
pass
raise ValueError
def learn(
self,

View File

@@ -13,7 +13,7 @@ agent_framework: RLLIB
# "A2C" (Advantage Actor Critic)
# "HARDCODED" (Custom Agent)
# "RANDOM" (Random Action)
red_agent_identifier: PPO
agent_identifier: PPO
# Sets How the Action Space is defined:
# "NODE"