#917 - Finished integrating all agents to either train (policy agents) or evaluate (hard-coded agents). Some fixing up still to do: tidying, loading, etc., and docs. But this is all now working.
This commit is contained in:
@@ -29,7 +29,7 @@ The environment config file consists of the following attributes:
|
||||
* SB3 - Stable Baselines3
|
||||
* RLLIB - Ray RLlib.
|
||||
|
||||
* **red_agent_identifier**
|
||||
* **agent_identifier**
|
||||
|
||||
This identifies the agent to use for the session. Select from one of the following:
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
2.0.0b1
|
||||
2.0.0rc1
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -12,7 +13,6 @@ from primaite.config import training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -196,50 +196,77 @@ class AgentSessionABC(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DeterministicAgentSessionABC(AgentSessionABC):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
pass
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
obs = self._env.reset()
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls):
|
||||
pass
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
pass
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
pass
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
|
||||
376
src/primaite/agents/hardcoded_acl.py
Normal file
376
src/primaite/agents/hardcoded_acl.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
get_node_of_ip,
|
||||
transform_action_acl_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.common.enums import HardCodedAgentView
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def _calculate_action(self, obs):
    """Select an ACL action for the current observation.

    Dispatches to the basic-view strategy (observation only) or the
    full-view strategy (observation plus ACL/IER state from the env),
    depending on the configured hard-coded agent view.
    """
    view = self._training_config.hard_coded_agent_view
    if view == HardCodedAgentView.BASIC:
        # Basic view: decide using nothing but the current observation.
        return self._calculate_action_basic_view(obs)
    # Full view: also consult existing ACL rules, action history and
    # reward feedback available through the environment.
    return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(self, green_iers, acl, nodes):
    """Return the subset of green IERs that the ACL currently blocks.

    :param green_iers: mapping of IER id -> green IER object.
    :param acl: the ACL to query via ``is_blocked``.
    :param nodes: mapping of node id -> node (provides ``ip_address``).
    :return: dict of IER id -> IER for every blocked green IER.
    """
    blocked = {}
    for ier_id, ier in green_iers.items():
        src_ip = nodes[ier.get_source_node_id()].ip_address
        dst_ip = nodes[ier.get_dest_node_id()].ip_address
        # An IER can be blocked by an explicit rule or by default
        # (no ALLOW rule exists for it).
        if acl.is_blocked(src_ip, dst_ip, ier.get_protocol(),
                          ier.get_port()):
            blocked[ier_id] = ier
    return blocked
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
    """Get the ACL rules whose pattern matches an IER.

    Resolves the IER's endpoint node ids to IP addresses and asks the
    ACL for every rule relevant to that (source, dest, protocol, port)
    tuple.
    """
    src_ip = nodes[ier.get_source_node_id()].ip_address
    dst_ip = nodes[ier.get_dest_node_id()].ip_address
    return acl.get_relevant_rules(
        src_ip, dst_ip, ier.get_protocol(), ier.get_port()
    )
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
    """Get the DENY ACL rules that match an IER.

    Warning: an empty result does not mean the IER flows — with no
    ALLOW rule present, the IER is still blocked by default.
    """
    candidates = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    return {key: rule for key, rule in candidates.items()
            if rule.get_permission() == "DENY"}
|
||||
|
||||
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
    """Get all ACL rules that ALLOW an IER."""
    candidates = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    return {key: rule for key, rule in candidates.items()
            if rule.get_permission() == "ALLOW"}
|
||||
|
||||
def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol,
                           port, acl, nodes, services_list):
    """Look up ACL rules relevant to the given endpoints/protocol/port.

    Node ids are resolved to IP addresses unless they are the literal
    string "ANY", which is passed through unchanged. A numeric protocol
    index is translated to its service name before querying the ACL.
    """
    src_ip = (nodes[str(source_node_id)].ip_address
              if source_node_id != "ANY" else source_node_id)
    dst_ip = (nodes[str(dest_node_id)].ip_address
              if dest_node_id != "ANY" else dest_node_id)
    if protocol != "ANY":
        # The protocol index is 1-based in the action space while the
        # services list has no leading "ANY" entry, hence the -1.
        protocol = services_list[protocol - 1]
    return acl.get_relevant_rules(src_ip, dst_ip, protocol, port)
|
||||
|
||||
def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol,
                        port, acl, nodes, services_list):
    """Get all ALLOW rules matching the given endpoints/protocol/port."""
    candidates = self.get_matching_acl_rules(source_node_id, dest_node_id,
                                             protocol, port, acl, nodes,
                                             services_list)
    return {key: rule for key, rule in candidates.items()
            if rule.get_permission() == "ALLOW"}
|
||||
|
||||
def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port,
                       acl,
                       nodes, services_list):
    """Get all DENY rules matching the given endpoints/protocol/port.

    Mirrors get_allow_acl_rules, keeping only rules whose permission is
    "DENY". (Fix: the result was previously accumulated in a local
    misleadingly named ``allowed_rules``.)
    """
    matching_rules = self.get_matching_acl_rules(source_node_id,
                                                 dest_node_id,
                                                 protocol, port, acl,
                                                 nodes,
                                                 services_list)

    deny_rules = {}
    for rule_key, rule_value in matching_rules.items():
        if rule_value.get_permission() == "DENY":
            deny_rules[rule_key] = rule_value

    return deny_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs):
    """Given an observation and the environment, calculate a good ACL-based action for the blue agent to take.

    Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:

    - Which ACL rules already exist - otherwise:
        - The agent would permanently get stuck in a loop of performing the same action over and over
          (best action is to block something, but it's already blocked and the agent doesn't know this).
        - The agent would be unable to interact with existing rules (e.g. how would it know to delete a
          rule if it doesn't know what rules exist).
    - The Green IERs (optional) - it often needs to know which traffic it should be allowing. For example,
      in the default config one of the green IERs is blocked by default, but it has no way of knowing this
      based on the observation space. Additionally, potentially in the future, once a node state has been
      fixed (no longer compromised), it needs a way to know it should re-allow traffic. An RL agent can
      learn what the green IERs are on its own - but the rule-based agent cannot easily do this.

    There doesn't seem to be much that can be done if an Operating or OS State is compromised.

    If a service node becomes compromised there's a decision to make - do we block that service?
        Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED.
        Cons: Will block a green IER, decreasing the reward.
    We decide to block the service.

    Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
    nodes once a service becomes overwhelmed. However, currently the ACL action space has no way of
    reversing an overwhelmed state, so we don't do this.
    """
    # obs = convert_to_old_obs(obs)
    r_obs = transform_change_obs_readable(obs)
    # Discard the ids/operating/os rows; keep one readable row per service.
    _, _, _, *s = r_obs

    if len(r_obs) == 4:  # only 1 service
        s = [*s]

    # 1. Check if a node is compromised. If so we want to block its outward services.
    #    a. If it is compromised check if there's an allow rule we should delete.
    #       cons: might delete a multi-rule from any source node (ANY -> x)
    #    b. OPTIONAL (Deny rules not needed): check if there already exists a Deny rule so as not to duplicate.
    #    c. OPTIONAL (no allow rule = blocked): add a DENY rule.
    found_action = False
    for service_num, service_states in enumerate(s):
        for x, service_state in enumerate(service_states):
            if service_state == "COMPROMISED":

                action_source_id = x + 1  # +1 as 0 is any
                action_destination_id = "ANY"
                action_protocol = service_num + 1  # +1 as 0 is any
                action_port = "ANY"

                allow_rules = self.get_allow_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                deny_rules = self.get_deny_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                if len(allow_rules) > 0:
                    # Check if there's an allow rule we should delete;
                    # take the first matching rule.
                    rule = list(allow_rules.values())[0]
                    action_decision = "DELETE"
                    action_permission = "ALLOW"
                    action_source_ip = rule.get_source_ip()
                    action_source_id = int(
                        get_node_of_ip(action_source_ip, self._env.nodes))
                    action_destination_ip = rule.get_dest_ip()
                    action_destination_id = int(
                        get_node_of_ip(action_destination_ip,
                                       self._env.nodes))
                    action_protocol_name = rule.get_protocol()
                    action_protocol = (
                        self._env.services_list.index(
                            action_protocol_name) + 1
                    )  # convert name e.g. 'TCP' to index
                    action_port_name = rule.get_port()
                    action_port = self._env.ports_list.index(
                        action_port_name) + 1  # convert port name e.g. '80' to index

                    found_action = True
                    break
                elif len(deny_rules) > 0:
                    # TODO OPTIONAL
                    # If there's already a DENY rule that blocks EVERYTHING from the source ip we don't
                    # need to create another.
                    # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule.
                    continue
                else:
                    # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked.
                    # NOTE(review): found_action is not set True here, so this CREATE/DENY decision can
                    # be overwritten by the green-IER logic below — confirm this is intended.
                    action_decision = "CREATE"
                    action_permission = "DENY"
                    break
        if found_action:
            break

    # 2. If NO node is compromised, or the node has already been blocked, check the green IERs and
    #    add an Allow rule if the green IER is being blocked.
    #    a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
    #       If there's a DENY rule delete it if:
    #           - There isn't already a deny rule
    #           - It doesn't allow a compromised node to become operational.
    #    b. Add an ALLOW rule if:
    #           - There isn't already an allow rule
    #           - It doesn't allow a compromised node to become operational.

    if not found_action:
        # Which Green IERS are blocked
        blocked_green_iers = self.get_blocked_green_iers(
            self._env.green_iers, self._env.acl,
            self._env.nodes)
        for ier_key, ier in blocked_green_iers.items():

            # Which ALLOW rules are allowing this IER (none)
            allowing_rules = self.get_allow_acl_rules_for_ier(ier,
                                                              self._env.acl,
                                                              self._env.nodes)

            # If there are no blocking rules, it may be being blocked by default.
            # If there is already an allow rule:
            node_id_to_check = int(ier.get_source_node_id())
            service_name_to_check = ier.get_protocol()
            service_id_to_check = self._env.services_list.index(
                service_name_to_check)

            # Service state of the source node in the IER
            service_state = s[service_id_to_check][node_id_to_check - 1]

            if len(allowing_rules) == 0 and service_state != "COMPROMISED":
                action_decision = "CREATE"
                action_permission = "ALLOW"
                action_source_id = int(ier.get_source_node_id())
                action_destination_id = int(ier.get_dest_node_id())
                action_protocol_name = ier.get_protocol()
                action_protocol = self._env.services_list.index(
                    action_protocol_name) + 1  # convert name e.g. 'TCP' to index
                action_port_name = ier.get_port()
                action_port = self._env.ports_list.index(
                    action_port_name) + 1  # convert port name e.g. '80' to index

                found_action = True
                break

    if found_action:
        action = [
            action_decision,
            action_permission,
            action_source_id,
            action_destination_id,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    else:
        # If no good/useful action has been found, just perform a nothing action
        action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs):
    """Given an observation, calculate a good ACL-based action for the blue agent to take.

    Uses ONLY information from the current observation with NO knowledge of previous actions taken and
    NO reward feedback.

    We rely on randomness to select the precise action, as we want to block all traffic originating from
    a compromised node, without being able to tell:
        1. Which ACL rules already exist.
        2. Which actions the agent has already tried.

    There is a high probability that the correct rule will not be deleted before the state becomes
    overwhelmed.

    Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted.
    """
    action_dict = self._env.action_dict
    r_obs = transform_change_obs_readable(obs)
    # Keep the operating-state row (o) and per-service state rows (s).
    _, o, _, *s = r_obs

    if len(r_obs) == 4:  # only 1 service
        s = [*s]

    number_of_nodes = len(
        [i for i in o if i != "NONE"])  # number of nodes (not links)
    for service_num, service_states in enumerate(s):
        # Indices of nodes whose state for this service is COMPROMISED
        # (variable name spelling kept as-is; it's the existing local name).
        comprimised_states = [n for n, i in enumerate(service_states) if
                              i == "COMPROMISED"]
        if len(comprimised_states) == 0:
            # No states are COMPROMISED, try the next service
            continue

        compromised_node = np.random.choice(
            comprimised_states) + 1  # +1 as 0 would be any
        action_decision = "DELETE"
        action_permission = "ALLOW"
        action_source_ip = compromised_node
        # Randomly select a destination ID to block
        action_destination_ip = np.random.choice(
            list(range(1, number_of_nodes + 1)) + ["ANY"])
        # np.random.choice on a mixed list yields strings; convert back
        # to int unless the sentinel "ANY" was drawn.
        action_destination_ip = int(
            action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
        action_protocol = service_num + 1  # +1 as 0 is any
        # Randomly select a port to block.
        # Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port.
        action_port = np.random.choice(list(range(1, len(s) + 1)))

        action = [
            action_decision,
            action_permission,
            action_source_ip,
            action_destination_ip,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, action_dict)
        # We can only perform 1 action on each step
        return action

    # If no good/useful action has been found, just perform a nothing action
    nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
    nothing_action = transform_action_acl_enum(nothing_action)
    nothing_action = get_new_action(nothing_action, action_dict)
    return nothing_action
|
||||
97
src/primaite/agents/hardcoded_node.py
Normal file
97
src/primaite/agents/hardcoded_node.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.agents.utils import (
|
||||
transform_action_node_enum,
|
||||
)
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
    """A hard-coded blue agent that acts on the node action space."""

    def _calculate_action(self, obs):
        """Given an observation calculate a good node-based action for the blue agent to take.

        Checks, in priority order: compromised OS states, compromised
        services, overwhelmed services, then OFF nodes; returns an action
        for the first problem found, otherwise a harmless no-op action.
        """
        action_dict = self._env.action_dict
        r_obs = transform_change_obs_readable(obs)
        _, o, os, *s = r_obs

        if len(r_obs) == 4:  # only 1 service
            s = [*s]

        # Check in order of most important states (order doesn't currently matter, but it probably should).
        # First see if any OS states are compromised.
        for x, os_state in enumerate(os):
            if os_state == "COMPROMISED":
                action_node_id = x + 1
                action_node_property = "OS"
                property_action = "PATCHING"
                action_service_index = 0  # does nothing; isn't relevant for OS
                action = [action_node_id, action_node_property,
                          property_action, action_service_index]
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action

        # Next, see if any services are compromised.
        # We fix the compromised state before the overwhelmed state: if a compromised entry node is fixed
        # before the overwhelmed state is triggered, the instruction is ignored.
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "COMPROMISED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num

                    action = [action_node_id, action_node_property,
                              property_action, action_service_index]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action

        # Next, see if any services are overwhelmed.
        # Perhaps this should be fixed automatically when the compromised PCs' issues are also resolved.
        # Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the
        # compromised PCs.
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "OVERWHELMED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num

                    action = [action_node_id, action_node_property,
                              property_action, action_service_index]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action

        # Finally, turn on any off nodes.
        for x, operating_state in enumerate(o):
            # BUG FIX: previously tested `os_state == "OFF"`, which re-read the
            # stale loop variable left over from the OS loop above instead of
            # this node's operating state, so OFF nodes were never reliably
            # turned back on.
            if operating_state == "OFF":
                action_node_id = x + 1
                action_node_property = "OPERATING"
                property_action = "ON"  # Why reset it when we can just turn it on
                action_service_index = 0  # does nothing; isn't relevant for operating state
                action = [action_node_id, action_node_property,
                          property_action, action_service_index]
                # BUG FIX: transform_action_node_enum takes a single argument
                # everywhere else in this method; the stray action_dict second
                # argument would have raised a TypeError.
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action

        # If no good actions, just go with an action that won't do any harm.
        action_node_id = 1
        action_node_property = "NONE"
        property_action = "ON"
        action_service_index = 0
        action = [action_node_id, action_node_property, property_action,
                  action_service_index]
        action = transform_action_node_enum(action)
        action = get_new_action(action, action_dict)

        return action
|
||||
@@ -1,21 +1,19 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, RedAgentIdentifier
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
def _env_creator(env_config):
|
||||
@@ -51,13 +49,13 @@ class RLlibAgent(AgentSessionABC):
|
||||
f"got {self._training_config.agent_framework}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = ("Expected PPO or A2C red_agent_identifier, "
|
||||
f"got {self._training_config.red_agent_identifier.value}")
|
||||
msg = ("Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: PPOConfig
|
||||
@@ -67,8 +65,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"red_agent_identifier="
|
||||
f"{self._training_config.red_agent_identifier}, "
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}, "
|
||||
f"deep_learning_framework="
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
@@ -117,7 +115,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._agent_config.framework(
|
||||
framework=self._training_config.deep_learning_framework
|
||||
framework="torch"
|
||||
)
|
||||
|
||||
self._agent_config.rollouts(
|
||||
|
||||
@@ -2,12 +2,12 @@ from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO, A2C
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import RedAgentIdentifier, AgentFramework
|
||||
from primaite.common.enums import AgentIdentifier, AgentFramework
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -24,13 +24,13 @@ class SB3Agent(AgentSessionABC):
|
||||
f"got {self._training_config.agent_framework}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = ("Expected PPO or A2C red_agent_identifier, "
|
||||
f"got {self._training_config.red_agent_identifier.value}")
|
||||
msg = ("Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -40,8 +40,8 @@ class SB3Agent(AgentSessionABC):
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"red_agent_identifier="
|
||||
f"{self._training_config.red_agent_identifier}"
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}"
|
||||
)
|
||||
|
||||
def _setup(self):
|
||||
@@ -56,7 +56,7 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self._training_config.output_verbose_level,
|
||||
verbose=self.output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path
|
||||
)
|
||||
@@ -118,6 +118,7 @@ class SB3Agent(AgentSessionABC):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
@classmethod
|
||||
def load(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
60
src/primaite/agents/simple.py
Normal file
60
src/primaite/agents/simple.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_action_acl_enum,
|
||||
transform_action_node_enum,
|
||||
)
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
    """A random agent.

    Samples a completely random action from the environment's action
    space on every step.
    """

    def _calculate_action(self, obs):
        """Return a uniformly sampled action; the observation is ignored."""
        action_space = self._env.action_space
        return action_space.sample()
|
||||
|
||||
|
||||
class DummyAgent(HardCodedAgentSessionABC):
    """A dummy agent.

    All action spaces are set up so that the dummy action is always 0
    regardless of the action type used.
    """

    def _calculate_action(self, obs):
        """Always return action 0; the observation is ignored."""
        dummy_action = 0
        return dummy_action
|
||||
|
||||
|
||||
class DoNothingACLAgent(HardCodedAgentSessionABC):
    """A do nothing ACL agent.

    Emits a valid ACL action that has no effect; does nothing.
    """

    def _calculate_action(self, obs):
        """Return the encoded ACL no-op action; the observation is ignored."""
        noop = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        encoded = transform_action_acl_enum(noop)
        return get_new_action(encoded, self._env.action_dict)
|
||||
|
||||
|
||||
class DoNothingNodeAgent(HardCodedAgentSessionABC):
    """A do nothing Node agent.

    Emits a valid Node action that has no effect; does nothing.
    """

    def _calculate_action(self, obs):
        """Return the encoded node no-op action; the observation is ignored."""
        noop = [1, "NONE", "ON", 0]
        encoded = transform_action_node_enum(noop)
        # The resulting encoded action should currently always be 0.
        return get_new_action(encoded, self._env.action_dict)
|
||||
@@ -1,4 +1,13 @@
|
||||
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
|
||||
import numpy as np
|
||||
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
LinkStatus,
|
||||
NodeHardwareAction,
|
||||
NodeSoftwareAction,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
@@ -125,3 +134,393 @@ def is_valid_acl_action_extra(action):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs):
    """Transform an observation array into a readable list per observation property.

    example:

    np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
    """
    readable = [
        list(obs[:, 0]),                             # node/link ids
        [HardwareState(v).name for v in obs[:, 1]],  # operating states
        [SoftwareState(v).name for v in obs[:, 2]],  # OS states
    ]
    # One readable column per service; link bit/s values (> 4) have no
    # service state and are passed through unchanged.
    for col in range(3, obs.shape[1]):
        readable.append(
            [SoftwareState(v).name if v <= 4 else v for v in obs[:, col]]
        )
    return readable
|
||||
|
||||
|
||||
def transform_obs_readable(obs):
    """Transform an observation array into one readable row per node/link.

    example:
    np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
    """
    per_property = transform_change_obs_readable(obs)
    # Transpose the property-major rows into one list per node/link.
    return [list(row) for row in zip(*per_property)]
|
||||
|
||||
|
||||
def convert_to_new_obs(obs, num_nodes=10):
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space"""
|
||||
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
|
||||
new_obs = obs[:num_nodes, 1:].flatten()
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Convert to old observation, links filled with 0's as no information is included in new observation space
|
||||
|
||||
example:
|
||||
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
|
||||
|
||||
new_obs = array([[ 1, 1, 1, 1],
|
||||
[ 2, 1, 1, 1],
|
||||
[ 3, 1, 1, 1],
|
||||
...
|
||||
[20, 0, 0, 0]])
|
||||
"""
|
||||
|
||||
# Convert back to more readable, original format
|
||||
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
|
||||
|
||||
# Add empty links back and add node ID back
|
||||
s = np.zeros([reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], dtype=np.int64)
|
||||
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
|
||||
s[:num_nodes, 1:] = reshaped_nodes # put values back in
|
||||
new_obs = s
|
||||
|
||||
# Add links back in
|
||||
links = obs[-num_links:]
|
||||
# Links will be added to the last protocol/service slot but they are not specific to that service
|
||||
new_obs[num_nodes:, -1] = links
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""Return string describing change between two observations
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
|
||||
output = 'ID 1: SERVICE 2 set to GOOD'
|
||||
|
||||
"""
|
||||
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
|
||||
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
|
||||
list_of_changes = []
|
||||
for n, row in enumerate(obs1 - obs2):
|
||||
if row.any() != 0:
|
||||
relevant_changes = np.where(row != 0, obs2[n], -1)
|
||||
relevant_changes[0] = obs2[n, 0] # ID is always relevant
|
||||
is_link = relevant_changes[0] > num_nodes
|
||||
desc = _describe_obs_change_helper(relevant_changes, is_link)
|
||||
list_of_changes.append(desc)
|
||||
|
||||
change_string = "\n ".join(list_of_changes)
|
||||
if len(list_of_changes) > 0:
|
||||
change_string = "\n " + change_string
|
||||
return change_string
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change, is_link):
|
||||
""" "
|
||||
Helper funcion to describe what has changed
|
||||
|
||||
example:
|
||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||
|
||||
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
|
||||
|
||||
"""
|
||||
# Indexes where a change has occured, not including 0th index
|
||||
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
|
||||
# Node pol types, Indexes >= 3 are service nodes
|
||||
NodePOLTypes = [
|
||||
NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed
|
||||
]
|
||||
# Account for hardware states, software sattes and links
|
||||
states = [
|
||||
LinkStatus(obs_change[i]).name
|
||||
if is_link
|
||||
else HardwareState(obs_change[i]).name
|
||||
if i == 1
|
||||
else SoftwareState(obs_change[i]).name
|
||||
for i in index_changed
|
||||
]
|
||||
|
||||
if not is_link:
|
||||
desc = f"ID {obs_change[0]}:"
|
||||
for NodePOLType, state in list(zip(NodePOLTypes, states)):
|
||||
desc = desc + " " + NodePOLType + " changed to " + state + "."
|
||||
else:
|
||||
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
|
||||
|
||||
return desc
|
||||
|
||||
|
||||
def transform_action_node_enum(action):
|
||||
"""
|
||||
Convert a node action from readable string format, to enumerated format
|
||||
|
||||
example:
|
||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||
"""
|
||||
|
||||
action_node_id = action[0]
|
||||
action_node_property = NodePOLType[action[1]].value
|
||||
|
||||
if action[1] == "OPERATING":
|
||||
property_action = NodeHardwareAction[action[2]].value
|
||||
elif action[1] == "OS" or action[1] == "SERVICE":
|
||||
property_action = NodeSoftwareAction[action[2]].value
|
||||
else:
|
||||
property_action = 0
|
||||
|
||||
action_service_index = action[3]
|
||||
|
||||
new_action = [action_node_id, action_node_property, property_action, action_service_index]
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
"""
|
||||
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def node_action_description(action):
|
||||
"""
|
||||
Generate string describing a node-based action
|
||||
"""
|
||||
|
||||
if isinstance(action[1], (int, np.int64)):
|
||||
# transform action to readable format
|
||||
action = transform_action_node_readable(action)
|
||||
|
||||
node_id = action[0]
|
||||
node_property = action[1]
|
||||
property_action = action[2]
|
||||
service_id = action[3]
|
||||
|
||||
if property_action == "NONE":
|
||||
return ""
|
||||
if node_property == "OPERATING" or node_property == "OS":
|
||||
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
elif node_property == "SERVICE":
|
||||
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
return description
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
"""
|
||||
Transform an ACL action to a more readable format
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
"""
|
||||
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == 0:
|
||||
new_action[n + 2] = "ANY"
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_enum(action):
|
||||
"""
|
||||
Convert a acl action from readable string format, to enumerated format
|
||||
"""
|
||||
|
||||
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
|
||||
action_permissions = {"DENY": 0, "ALLOW": 1}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == "ANY":
|
||||
new_action[n + 2] = 0
|
||||
|
||||
new_action = np.array(new_action)
|
||||
return new_action
|
||||
|
||||
|
||||
def acl_action_description(action):
|
||||
"""generate string describing a acl-based action"""
|
||||
|
||||
if isinstance(action[0], (int, np.int64)):
|
||||
# transform action to readable format
|
||||
action = transform_action_acl_readable(action)
|
||||
if action[0] == "NONE":
|
||||
description = "NO ACL RULE APPLIED"
|
||||
else:
|
||||
description = (
|
||||
f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
|
||||
f" for protocol/service index {action[4]} on port index {action[5]}"
|
||||
)
|
||||
|
||||
return description
|
||||
|
||||
|
||||
def get_node_of_ip(ip, node_dict):
|
||||
"""
|
||||
Get the node ID of an IP address
|
||||
|
||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
||||
"""
|
||||
|
||||
for node_key, node_value in node_dict.items():
|
||||
node_ip = node_value.ip_address
|
||||
if node_ip == ip:
|
||||
return node_key
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
"""Is the node action an actual valid action
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
node_property = action_r[1]
|
||||
node_action = action_r[2]
|
||||
|
||||
if node_property == "NONE":
|
||||
return False
|
||||
if node_action == "NONE":
|
||||
return False
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""
|
||||
Is the ACL action an actual valid action
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
action_decision = action_r[0]
|
||||
action_permission = action_r[1]
|
||||
action_source_id = action_r[2]
|
||||
action_destination_id = action_r[3]
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
# DENY is unnecessary, we can create and delete allow rules instead
|
||||
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action"""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
action_r = transform_action_acl_readable(action)
|
||||
action_protocol = action_r[4]
|
||||
action_port = action_r[5]
|
||||
|
||||
# Don't allow protocols or ports to be ANY
|
||||
# in the future we might want to do the opposite, and only have ANY option for ports and service
|
||||
if action_protocol == "ANY":
|
||||
return False
|
||||
if action_port == "ANY":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_new_action(old_action, action_dict):
|
||||
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0]
|
||||
|
||||
old_action can be either node or acl action type
|
||||
"""
|
||||
|
||||
for key, val in action_dict.items():
|
||||
if list(val) == list(old_action):
|
||||
return key
|
||||
# Not all possible actions are included in dict, only valid action are
|
||||
# if action is not in the dict, its an invalid action so return 0
|
||||
return 0
|
||||
|
||||
|
||||
def get_action_description(action, action_dict):
|
||||
"""
|
||||
Get a string describing/explaining what an action is doing in words
|
||||
"""
|
||||
|
||||
action_array = action_dict[action]
|
||||
if len(action_array) == 4:
|
||||
# node actions have length 4
|
||||
action_description = node_action_description(action_array)
|
||||
elif len(action_array) == 6:
|
||||
# acl actions have length 6
|
||||
action_description = acl_action_description(action_array)
|
||||
else:
|
||||
# Should never happen
|
||||
action_description = "Unrecognised action"
|
||||
|
||||
return action_description
|
||||
|
||||
@@ -32,6 +32,7 @@ class Priority(Enum):
|
||||
class HardwareState(Enum):
|
||||
"""Node hardware state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESETTING = 3
|
||||
@@ -42,6 +43,7 @@ class HardwareState(Enum):
|
||||
class SoftwareState(Enum):
|
||||
"""Software or Service state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
GOOD = 1
|
||||
PATCHING = 2
|
||||
COMPROMISED = 3
|
||||
@@ -94,7 +96,8 @@ class VerboseLevel(IntEnum):
|
||||
|
||||
|
||||
class AgentFramework(Enum):
|
||||
NONE = 0
|
||||
"""The agent algorithm framework/package."""
|
||||
CUSTOM = 0
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
"Stable Baselines3"
|
||||
@@ -103,7 +106,7 @@ class AgentFramework(Enum):
|
||||
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
"""The deep learning framework enumeration."""
|
||||
"""The deep learning framework."""
|
||||
TF = "tf"
|
||||
"Tensorflow"
|
||||
TF2 = "tf2"
|
||||
@@ -112,15 +115,28 @@ class DeepLearningFramework(Enum):
|
||||
"PyTorch"
|
||||
|
||||
|
||||
class RedAgentIdentifier(Enum):
|
||||
class AgentIdentifier(Enum):
|
||||
"""The Red Agent algo/class."""
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
"Proximal Policy Optimization"
|
||||
HARDCODED = 3
|
||||
"Custom Agent"
|
||||
RANDOM = 4
|
||||
"Custom Agent"
|
||||
"The Hardcoded agents"
|
||||
DO_NOTHING = 4
|
||||
"The DoNothing agents"
|
||||
RANDOM = 5
|
||||
"The RandomAgent"
|
||||
DUMMY = 6
|
||||
"The DummyAgent"
|
||||
|
||||
|
||||
class HardCodedAgentView(Enum):
|
||||
"""The view the deterministic hard-coded agent has of the environment."""
|
||||
BASIC = 1
|
||||
"The current observation space only"
|
||||
FULL = 2
|
||||
"Full environment view with actions taken and reward feedback"
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
|
||||
@@ -1,32 +1,41 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used:
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "NONE" (Custom Agent)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: RLLIB
|
||||
|
||||
# Sets which deep learning framework will be used. Default is TF (Tensorflow).
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TORCH
|
||||
|
||||
# Sets which Red Agent algo/class will be used:
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic)
|
||||
# "PPO" (Proximal Policy Optimization)
|
||||
# "HARDCODED" (Custom Agent)
|
||||
# "RANDOM" (Random Action)
|
||||
red_agent_identifier: PPO
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ACL
|
||||
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
|
||||
@@ -8,8 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
|
||||
import yaml
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import DeepLearningFramework
|
||||
from primaite.common.enums import ActionType, RedAgentIdentifier, \
|
||||
from primaite.common.enums import DeepLearningFramework, HardCodedAgentView
|
||||
from primaite.common.enums import ActionType, AgentIdentifier, \
|
||||
AgentFramework, SessionType, OutputVerboseLevel
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -42,8 +42,11 @@ class TrainingConfig:
|
||||
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
|
||||
"The DeepLearningFramework"
|
||||
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
|
||||
"The RedAgentIdentifier"
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
|
||||
"The AgentIdentifier"
|
||||
|
||||
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
|
||||
"The view the deterministic hard-coded agent has of the environment"
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use"
|
||||
@@ -176,10 +179,11 @@ class TrainingConfig:
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
"red_agent_identifier": RedAgentIdentifier,
|
||||
"agent_identifier": AgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType,
|
||||
"output_verbose_level": OutputVerboseLevel
|
||||
"output_verbose_level": OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView
|
||||
}
|
||||
|
||||
for field, enum_class in field_enum_map.items():
|
||||
@@ -197,12 +201,13 @@ class TrainingConfig:
|
||||
"""
|
||||
data = self.__dict__
|
||||
if json_serializable:
|
||||
data["agent_framework"] = self.agent_framework.value
|
||||
data["deep_learning_framework"] = self.deep_learning_framework.value
|
||||
data["red_agent_identifier"] = self.red_agent_identifier.value
|
||||
data["action_type"] = self.action_type.value
|
||||
data["output_verbose_level"] = self.output_verbose_level.value
|
||||
data["session_type"] = self.session_type.value
|
||||
data["agent_framework"] = self.agent_framework.name
|
||||
data["deep_learning_framework"] = self.deep_learning_framework.name
|
||||
data["agent_identifier"] = self.agent_identifier.name
|
||||
data["action_type"] = self.action_type.name
|
||||
data["output_verbose_level"] = self.output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
|
||||
return data
|
||||
|
||||
@@ -255,7 +260,7 @@ def load(
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_steps: int = 256,
|
||||
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
|
||||
@@ -266,8 +271,8 @@ def convert_legacy_training_config_dict(
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param agent_framework: The agent framework to use as legacy training
|
||||
configs don't have agent_framework values.
|
||||
:param red_agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have red_agent_identifier values.
|
||||
:param agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have agent_identifier values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:param num_steps: The number of steps to set as legacy training configs
|
||||
@@ -278,7 +283,7 @@ def convert_legacy_training_config_dict(
|
||||
"""
|
||||
config_dict = {
|
||||
"agent_framework": agent_framework.name,
|
||||
"red_agent_identifier": red_agent_identifier.name,
|
||||
"agent_identifier": agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_steps": num_steps,
|
||||
"output_verbose_level": output_verbose_level
|
||||
|
||||
@@ -97,7 +97,7 @@ class Primaite(Env):
|
||||
self.transaction_list = transaction_list
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.training_config.red_agent_identifier
|
||||
self.agent_identifier = self.training_config.agent_identifier
|
||||
|
||||
# Create a dictionary to hold all the nodes
|
||||
self.nodes: Dict[str, NodeUnion] = {}
|
||||
|
||||
@@ -1,137 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The main PrimAITE session runner module.
|
||||
|
||||
TODO: This will eventually be refactored out into a proper Session class.
|
||||
TODO: The passing about of session_path and timestamp_str is temporary and
|
||||
will be cleaned up once we move to a proper Session class.
|
||||
"""
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Union
|
||||
from uuid import uuid4
|
||||
from typing import Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite import getLogger
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
"""
|
||||
Run against a generic agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
"""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
env.reset()
|
||||
for step in range(0, config_values.num_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo(
|
||||
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 PPO agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = PPO.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(config_values.num_episodes):
|
||||
agent.learn(total_timesteps=config_values.num_steps)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
|
||||
|
||||
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
|
||||
"""
|
||||
Persist an agent.
|
||||
|
||||
Only works for stable baselines3 agents at present.
|
||||
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if not isinstance(agent, OnPolicyAlgorithm):
|
||||
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
|
||||
_LOGGER.error(msg)
|
||||
else:
|
||||
filepath = session_path / f"agent_saved_{timestamp_str}"
|
||||
agent.save(filepath)
|
||||
_LOGGER.debug(f"Trained agent saved as: {filepath}")
|
||||
|
||||
|
||||
|
||||
|
||||
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
|
||||
"""Run the PrimAITE Session.
|
||||
|
||||
|
||||
@@ -8,9 +8,13 @@ from uuid import uuid4
|
||||
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \
|
||||
RandomAgent, DummyAgent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, \
|
||||
ActionType, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
@@ -68,31 +72,66 @@ class PrimaiteSession:
|
||||
self.learn()
|
||||
|
||||
def setup(self):
|
||||
if self._training_config.agent_framework == AgentFramework.NONE:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
|
||||
# Stochastic Random Agent
|
||||
raise NotImplementedError
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
raise NotImplementedError
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
raise NotImplementedError
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid RedAgentIdentifier ActionType combo
|
||||
pass
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
# Stable Baselines3 Agent
|
||||
@@ -110,7 +149,7 @@ class PrimaiteSession:
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
pass
|
||||
raise ValueError
|
||||
|
||||
def learn(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@ agent_framework: RLLIB
|
||||
# "A2C" (Advantage Actor Critic)
|
||||
# "HARDCODED" (Custom Agent)
|
||||
# "RANDOM" (Random Action)
|
||||
red_agent_identifier: PPO
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
|
||||
Reference in New Issue
Block a user