Merge remote-tracking branch 'origin/dev' into feature/1468-observations-class
This commit is contained in:
0
src/primaite/agents/__init__.py
Normal file
0
src/primaite/agents/__init__.py
Normal file
127
src/primaite/agents/utils.py
Normal file
127
src/primaite/agents/utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
|
||||
2
|
||||
] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
"""
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == 0:
|
||||
new_action[n + 2] = "ANY"
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
"""Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
node_property = action_r[1]
|
||||
node_action = action_r[2]
|
||||
|
||||
# print("node property", node_property, "\nnode action", node_action)
|
||||
|
||||
if node_property == "NONE":
|
||||
return False
|
||||
if node_action == "NONE":
|
||||
return False
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect.
|
||||
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
action_decision = action_r[0]
|
||||
action_permission = action_r[1]
|
||||
action_source_id = action_r[2]
|
||||
action_destination_id = action_r[3]
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
# DENY is unnecessary, we can create and delete allow rules instead
|
||||
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action."""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
action_r = transform_action_acl_readable(action)
|
||||
action_protocol = action_r[4]
|
||||
action_port = action_r[5]
|
||||
|
||||
# Don't allow protocols or ports to be ANY
|
||||
# in the future we might want to do the opposite, and only have ANY option for ports and service
|
||||
if action_protocol == "ANY":
|
||||
return False
|
||||
if action_port == "ANY":
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -49,6 +49,7 @@ class SoftwareState(Enum):
|
||||
class NodePOLType(Enum):
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
NONE = 0
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
@@ -81,6 +82,7 @@ class ActionType(Enum):
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
|
||||
|
||||
class ObservationType(Enum):
|
||||
@@ -98,3 +100,29 @@ class FileSystemState(Enum):
|
||||
DESTROYED = 3
|
||||
REPAIRING = 4
|
||||
RESTORING = 5
|
||||
|
||||
|
||||
class NodeHardwareAction(Enum):
|
||||
"""Node hardware action."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESET = 3
|
||||
|
||||
|
||||
class NodeSoftwareAction(Enum):
|
||||
"""Node software action."""
|
||||
|
||||
NONE = 0
|
||||
PATCHING = 1
|
||||
|
||||
|
||||
class LinkStatus(Enum):
|
||||
"""Link traffic status."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
OVERLOAD = 4
|
||||
|
||||
@@ -15,6 +15,7 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -42,6 +43,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
@@ -205,15 +207,9 @@ class Primaite(Env):
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.num_nodes,
|
||||
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
|
||||
self.ACTION_SPACE_NODE_ACTION_VALUES,
|
||||
self.num_services,
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.action_type == ActionType.ACL:
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
@@ -222,17 +218,14 @@ class Primaite(Env):
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.ACTION_SPACE_ACL_ACTION_VALUES,
|
||||
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
|
||||
self.num_nodes + 1,
|
||||
self.num_nodes + 1,
|
||||
self.num_services + 1,
|
||||
self.num_ports + 1,
|
||||
]
|
||||
)
|
||||
|
||||
self.action_dict = self.create_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.action_type == ActionType.ANY:
|
||||
_LOGGER.info("Action space type ANY selected - Node + ACL")
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info("Invalid action type selected")
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
@@ -378,7 +371,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.config_values,
|
||||
)
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -435,8 +428,18 @@ class Primaite(Env):
|
||||
# At the moment, actions are only affecting nodes
|
||||
if self.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
elif self.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""
|
||||
@@ -445,10 +448,11 @@ class Primaite(Env):
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
node_id = _action[0]
|
||||
node_property = _action[1]
|
||||
property_action = _action[2]
|
||||
service_index = _action[3]
|
||||
readable_action = self.action_dict[_action]
|
||||
node_id = readable_action[0]
|
||||
node_property = readable_action[1]
|
||||
property_action = readable_action[2]
|
||||
service_index = readable_action[3]
|
||||
|
||||
# Check that the action is requesting a valid node
|
||||
try:
|
||||
@@ -534,12 +538,15 @@ class Primaite(Env):
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
action_decision = _action[0]
|
||||
action_permission = _action[1]
|
||||
action_source_ip = _action[2]
|
||||
action_destination_ip = _action[3]
|
||||
action_protocol = _action[4]
|
||||
action_port = _action[5]
|
||||
# Convert discrete value back to multidiscrete
|
||||
readable_action = self.action_dict[_action]
|
||||
|
||||
action_decision = readable_action[0]
|
||||
action_permission = readable_action[1]
|
||||
action_source_ip = readable_action[2]
|
||||
action_destination_ip = readable_action[3]
|
||||
action_protocol = readable_action[4]
|
||||
action_port = readable_action[5]
|
||||
|
||||
if action_decision == 0:
|
||||
# It's decided to do nothing
|
||||
@@ -1119,3 +1126,91 @@ class Primaite(Env):
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
|
||||
def create_node_action_dict(self):
|
||||
"""
|
||||
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
|
||||
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
|
||||
|
||||
example return:
|
||||
{0: [1, 0, 0, 0],
|
||||
1: [1, 1, 1, 0],
|
||||
2: [1, 1, 2, 0],
|
||||
3: [1, 1, 3, 0],
|
||||
4: [1, 2, 1, 0],
|
||||
5: [1, 3, 1, 0],
|
||||
...
|
||||
}
|
||||
|
||||
"""
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [1, 0, 0, 0]}
|
||||
action_key = 1
|
||||
for node in range(1, self.num_nodes + 1):
|
||||
# 4 node properties (NONE, OPERATING, OS, SERVICE)
|
||||
for node_property in range(4):
|
||||
# Node Actions either:
|
||||
# (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state
|
||||
# Use MAX to ensure we get them all
|
||||
for node_action in range(4):
|
||||
for service_state in range(self.num_services):
|
||||
action = [node, node_property, node_action, service_state]
|
||||
# check to see if it's a nothing action (has no effect)
|
||||
if is_valid_node_action(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
def create_acl_action_dict(self):
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
for action_decision in range(3):
|
||||
# 2 possible action permissions 0 = DENY, 1 = CREATE
|
||||
for action_permission in range(2):
|
||||
# Number of nodes + 1 (for any)
|
||||
for source_ip in range(self.num_nodes + 1):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
source_ip,
|
||||
dest_ip,
|
||||
protocol,
|
||||
port,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
def create_node_and_acl_action_dict(self):
|
||||
"""
|
||||
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
|
||||
The dictionary contains actions of both Node and ACL action types.
|
||||
|
||||
"""
|
||||
node_action_dict = self.create_node_action_dict()
|
||||
acl_action_dict = self.create_acl_action_dict()
|
||||
|
||||
# Change node keys to not overlap with acl keys
|
||||
# Only 1 nothing action (key 0) is required, remove the other
|
||||
new_node_action_dict = {
|
||||
k + len(acl_action_dict) - 1: v
|
||||
for k, v in node_action_dict.items()
|
||||
if k != 0
|
||||
}
|
||||
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
return combined_action_dict
|
||||
|
||||
Reference in New Issue
Block a user