Merge remote-tracking branch 'origin/dev' into feature/1468-observations-class

This commit is contained in:
Marek Wolan
2023-06-09 09:01:54 +01:00
11 changed files with 643 additions and 47 deletions

View File

@@ -15,6 +15,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -42,6 +43,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
class Primaite(Env):
@@ -205,15 +207,9 @@ class Primaite(Env):
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_space = spaces.MultiDiscrete(
[
self.num_nodes,
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
self.ACTION_SPACE_NODE_ACTION_VALUES,
self.num_services,
]
)
else:
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
@@ -222,17 +218,14 @@ class Primaite(Env):
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_space = spaces.MultiDiscrete(
[
self.ACTION_SPACE_ACL_ACTION_VALUES,
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
self.num_nodes + 1,
self.num_nodes + 1,
self.num_services + 1,
self.num_ports + 1,
]
)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info("Invalid action type selected")
# Set up a csv to store the results of the training
try:
now = datetime.now() # current date and time
@@ -378,7 +371,7 @@ class Primaite(Env):
self.step_count,
self.config_values,
)
# print(f" Step {self.step_count} Reward: {str(reward)}")
print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -435,8 +428,18 @@ class Primaite(Env):
# At the moment, actions are only affecting nodes
if self.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
else:
elif self.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
"""
@@ -445,10 +448,11 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
node_id = _action[0]
node_property = _action[1]
property_action = _action[2]
service_index = _action[3]
readable_action = self.action_dict[_action]
node_id = readable_action[0]
node_property = readable_action[1]
property_action = readable_action[2]
service_index = readable_action[3]
# Check that the action is requesting a valid node
try:
@@ -534,12 +538,15 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
action_decision = _action[0]
action_permission = _action[1]
action_source_ip = _action[2]
action_destination_ip = _action[3]
action_protocol = _action[4]
action_port = _action[5]
# Convert discrete value back to multidiscrete
readable_action = self.action_dict[_action]
action_decision = readable_action[0]
action_permission = readable_action[1]
action_source_ip = readable_action[2]
action_destination_ip = readable_action[3]
action_protocol = readable_action[4]
action_port = readable_action[5]
if action_decision == 0:
# It's decided to do nothing
@@ -1119,3 +1126,91 @@ class Primaite(Env):
else:
# Bad formatting
pass
def create_node_action_dict(self):
    """
    Build the lookup table from discrete node action ids to multidiscrete node actions.

    Each value is a list ``[node, node_property, node_action, service_state]``.
    Key 0 is reserved for the "do nothing" action ``[1, 0, 0, 0]``; every other
    candidate is only given a key if ``is_valid_node_action`` accepts it, so
    keys are consecutive integers with no gaps.

    Example return:
        {0: [1, 0, 0, 0],
         1: [1, 1, 1, 0],
         2: [1, 1, 2, 0],
         3: [1, 1, 3, 0],
         4: [1, 2, 1, 0],
         5: [1, 3, 1, 0],
         ...
        }

    Returns:
        dict mapping an int action id to its 4-element action list.
    """
    # Key 0 is always the nothing action.
    actions = {0: [1, 0, 0, 0]}

    # Lazily enumerate every combination, in the same nesting order as the
    # multidiscrete layout: node (1-based), then property, then action, then
    # service index.
    candidates = (
        [node_id, property_id, action_id, service_id]
        for node_id in range(1, self.num_nodes + 1)
        # 4 node properties (NONE, OPERATING, OS, SERVICE)
        for property_id in range(4)
        # Up to 4 actions per property (NONE, ON, OFF, RESET) or (NONE, PATCH);
        # the maximum is used so that none are missed.
        for action_id in range(4)
        for service_id in range(self.num_services)
    )

    for candidate in candidates:
        # Skip combinations that would have no effect on the environment.
        if is_valid_node_action(candidate):
            # len(actions) is always the next unused consecutive key.
            actions[len(actions)] = candidate

    return actions
def create_acl_action_dict(self):
    """
    Build the lookup table from discrete ACL action ids to multidiscrete ACL actions.

    Each value is a list
    ``[decision, permission, source_ip, dest_ip, protocol, port]``.
    Key 0 is reserved for the "do nothing" action ``[0, 0, 0, 0, 0, 0]``; every
    other candidate is only given a key if ``is_valid_acl_action_extra``
    accepts it, so keys are consecutive integers with no gaps.

    Returns:
        dict mapping an int action id to its 6-element action list.
    """
    # Key 0 is always the nothing action.
    actions = {0: [0, 0, 0, 0, 0, 0]}

    candidates = (
        [decision, permission, src, dst, proto, prt]
        # 3 possible decisions: 0 = NOTHING, 1 = CREATE, 2 = DELETE
        for decision in range(3)
        # 2 possible permissions: 0 = DENY, 1 = the other permission
        # (presumably ALLOW - the original comment said CREATE; confirm)
        for permission in range(2)
        # The "+ 1" on each range below leaves room for index 0 meaning "any".
        for src in range(self.num_nodes + 1)
        for dst in range(self.num_nodes + 1)
        for proto in range(self.num_services + 1)
        for prt in range(self.num_ports + 1)
    )

    for candidate in candidates:
        # Keep only the actions we want to be selectable, i.e. not a nothing action.
        if is_valid_acl_action_extra(candidate):
            # len(actions) is always the next unused consecutive key.
            actions[len(actions)] = candidate

    return actions
def create_node_and_acl_action_dict(self):
    """
    Build a single discrete-to-multidiscrete lookup table covering Node and ACL actions.

    The ACL actions keep their original keys. The node actions (minus their
    own key-0 nothing action, since only one nothing action is needed) are
    appended after them with their keys shifted so the two sets do not overlap.

    Returns:
        dict mapping an int action id to either a 6-element ACL action list
        or a 4-element node action list.
    """
    node_actions = self.create_node_action_dict()
    acl_actions = self.create_acl_action_dict()

    # Node key 1 must land on the first key after the last ACL key.
    offset = len(acl_actions) - 1

    # Start from a copy of the ACL table, then add the re-keyed node actions.
    combined = dict(acl_actions)
    for key, action in node_actions.items():
        if key == 0:
            # The ACL table already provides the single nothing action.
            continue
        combined[key + offset] = action

    return combined