901 - updated observations.py to change and add new mapping of ACL rules to represent no rule present in list

This commit is contained in:
SunilSamra
2023-07-05 09:08:03 +01:00
parent 41aed12f27
commit 3f440c0a28
7 changed files with 95 additions and 76 deletions

View File

@@ -39,7 +39,9 @@ class AccessControlList:
""" """
if self.acl_implicit_rule is not None: if self.acl_implicit_rule is not None:
acl_list = self._acl + [self.acl_implicit_rule] acl_list = self._acl + [self.acl_implicit_rule]
return acl_list + [-1] * (self.max_acl_rules - len(acl_list)) else:
acl_list = self._acl
return acl_list + [None] * (self.max_acl_rules - len(acl_list))
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
""" """
@@ -86,15 +88,18 @@ class AccessControlList:
Indicates block if all conditions are satisfied. Indicates block if all conditions are satisfied.
""" """
for rule in self.acl: for rule in self.acl:
if self.check_address_match(rule, _source_ip_address, _dest_ip_address): if isinstance(rule, ACLRule):
if ( if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" if (
) and (str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"): rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
# There's a matching rule. Get the permission ) and (
if rule.get_permission() == "DENY": str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
return True ):
elif rule.get_permission() == "ALLOW": # There's a matching rule. Get the permission
return False if rule.get_permission() == "DENY":
return True
elif rule.get_permission() == "ALLOW":
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default # If there has been no rule to allow the IER through, it will return a blocked signal by default
return True return True
@@ -115,7 +120,6 @@ class AccessControlList:
""" """
position_index = int(_position) position_index = int(_position)
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
print(len(self._acl))
if len(self._acl) + 1 < self.max_acl_rules: if len(self._acl) + 1 < self.max_acl_rules:
if _position is not None: if _position is not None:
if self.max_acl_rules - 1 > position_index > -1: if self.max_acl_rules - 1 > position_index > -1:
@@ -136,6 +140,7 @@ class AccessControlList:
f"The ACL list is FULL." f"The ACL list is FULL."
f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}." f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}."
) )
# print("length of this list", len(self._acl))
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
""" """

View File

@@ -131,7 +131,8 @@ class LinkStatus(Enum):
class RulePermissionType(Enum): class RulePermissionType(Enum):
"""Implicit firewall rule.""" """Any firewall rule type."""
DENY = 0 NA = 0
ALLOW = 1 DENY = 1
ALLOW = 2

View File

@@ -5,14 +5,14 @@
# "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C" # "STABLE_BASELINES3_A2C"
# "GENERIC" # "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C agent_identifier: STABLE_BASELINES3_PPO
# Sets How the Action Space is defined: # Sets How the Action Space is defined:
# "NODE" # "NODE"
# "ACL" # "ACL"
# "ANY" node and acl actions # "ANY" node and acl actions
action_type: NODE action_type: ACL
# Number of episodes to run per session # Number of episodes to run per session
num_episodes: 10 num_episodes: 1000
# Number of time_steps per episode # Number of time_steps per episode
num_steps: 256 num_steps: 256
# Time delay between steps (for generic agents) # Time delay between steps (for generic agents)

View File

@@ -10,7 +10,6 @@ from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import ( from primaite.common.enums import (
FileSystemState, FileSystemState,
HardwareState, HardwareState,
Protocol,
RulePermissionType, RulePermissionType,
SoftwareState, SoftwareState,
) )
@@ -330,13 +329,14 @@ class AccessControlList(AbstractObservationComponent):
] ]
""" """
0,
# Terms (for ACL observation space): # Terms (for ACL observation space):
# [0, 1] - Permission (0 = DENY, 1 = ALLOW) # [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port) # [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list) # [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
_DATA_TYPE: type = np.int64 _DATA_TYPE: type = np.int64
@@ -346,18 +346,17 @@ class AccessControlList(AbstractObservationComponent):
# 1. Define the shape of your observation space component # 1. Define the shape of your observation space component
acl_shape = [ acl_shape = [
len(RulePermissionType), len(RulePermissionType),
len(env.nodes) + 1, len(env.nodes) + 2,
len(env.nodes) + 1, len(env.nodes) + 2,
len(env.services_list), len(env.services_list) + 1,
len(env.ports_list), len(env.ports_list) + 1,
env.max_number_acl_rules, env.max_number_acl_rules + 1,
] ]
# shape = acl_shape
shape = acl_shape * self.env.max_number_acl_rules shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space # 2. Create Observation space
self.space = spaces.MultiDiscrete(shape) self.space = spaces.MultiDiscrete(shape)
print("obs space:", self.space) # print("obs space:", self.space)
# 3. Initialise observation with zeroes # 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
@@ -367,8 +366,8 @@ class AccessControlList(AbstractObservationComponent):
The structure of the observation space is described in :class:`.AccessControlList` The structure of the observation space is described in :class:`.AccessControlList`
""" """
obs = [] obs = []
# print("starting len", len(self.env.acl.acl))
for index in range(len(self.env.acl.acl)): for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index] acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule): if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission permission = acl_rule.permission
@@ -378,26 +377,25 @@ class AccessControlList(AbstractObservationComponent):
port = acl_rule.port port = acl_rule.port
position = index position = index
source_ip_int = -1 source_ip_int = None
dest_ip_int = -1 dest_ip_int = None
if permission == "DENY": if permission == "DENY":
permission_int = 0
else:
permission_int = 1 permission_int = 1
else:
permission_int = 2
if source_ip == "ANY": if source_ip == "ANY":
source_ip_int = 0 source_ip_int = 1
else: else:
nodes = list(self.env.nodes.values()) nodes = list(self.env.nodes.values())
for node in nodes: for node in nodes:
# print(node.ip_address, source_ip, node.ip_address == source_ip)
if ( if (
isinstance(node, ServiceNode) isinstance(node, ServiceNode)
or isinstance(node, ActiveNode) or isinstance(node, ActiveNode)
) and node.ip_address == source_ip: ) and node.ip_address == source_ip:
source_ip_int = node.node_id source_ip_int = int(node.node_id) + 1
break break
if dest_ip == "ANY": if dest_ip == "ANY":
dest_ip_int = 0 dest_ip_int = 1
else: else:
nodes = list(self.env.nodes.values()) nodes = list(self.env.nodes.values())
for node in nodes: for node in nodes:
@@ -405,46 +403,46 @@ class AccessControlList(AbstractObservationComponent):
isinstance(node, ServiceNode) isinstance(node, ServiceNode)
or isinstance(node, ActiveNode) or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip: ) and node.ip_address == dest_ip:
dest_ip_int = node.node_id dest_ip_int = int(node.node_id) + 1
if protocol == "ANY": if protocol == "ANY":
protocol_int = 0 protocol_int = 1
else: else:
try: try:
protocol_int = Protocol[protocol].value protocol_int = self.env.services_list.index(protocol) + 2
except AttributeError: except AttributeError:
_LOGGER.info(f"Service {protocol} could not be found") _LOGGER.info(f"Service {protocol} could not be found")
protocol_int = -1 protocol_int = None
if port == "ANY": if port == "ANY":
port_int = 0 port_int = 1
else: else:
if port in self.env.ports_list: if port in self.env.ports_list:
port_int = self.env.ports_list.index(port) port_int = self.env.ports_list.index(port) + 2
else: else:
_LOGGER.info(f"Port {port} could not be found.") _LOGGER.info(f"Port {port} could not be found.")
# Either do the multiply on the obs space # Either do the multiply on the obs space
# Change the obs to # Change the obs to
if source_ip_int != -1 and dest_ip_int != -1: items_to_add = [
items_to_add = [ permission_int,
permission_int, source_ip_int,
source_ip_int, dest_ip_int,
dest_ip_int, protocol_int,
protocol_int, port_int,
port_int, position,
position, ]
] position = position * 6
position = position * 6 for item in items_to_add:
for item in items_to_add: # print("position", position, "\nitem", int(item))
obs.insert(position, int(item)) obs.insert(position, int(item))
position += 1 position += 1
else: else:
items_to_add = [-1, -1, -1, -1, -1, index] starting_position = index * 6
position = index * 6 for placeholder in range(6):
for item in items_to_add: obs.insert(starting_position, 0)
obs.insert(position, int(item)) starting_position += 1
position += 1
self.current_observation = obs # print("current obs", obs, "\n" ,len(obs))
self.current_observation[:] = obs
class ObservationsHandler: class ObservationsHandler:

View File

@@ -1204,7 +1204,8 @@ class Primaite(Env):
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list) # [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
# reserve 0 action to be a nothing action # reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0, 0]} actions = {0: [0, 0, 0, 0, 0, 0, 0]}
# [1, 1, 2, 1, 1, 1, 2] CREATE RULE ALLOW NODE 2 TO NODE 1 ON SERVICE 1 PORT 1 AT INDEX 2
# 1, 2, 1, 6, 0, 0, 1 ALLOW NODE 2 TO NODE 1 ON SERVICE 1 SERVICE ANY PORT ANY AT INDEX 1
action_key = 1 action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3): for action_decision in range(3):

View File

@@ -26,12 +26,12 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE) # Choice whether to have an ALLOW or DENY implicit rule or not (True or False)
apply_implicit_rule: True apply_implicit_rule: True
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY) # Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment # Total number of ACL rules allowed in the environment
max_number_acl_rules: 10 max_number_acl_rules: 3
observation_space: observation_space:
components: components:

View File

@@ -18,7 +18,7 @@ from tests.conftest import _get_primaite_env_from_config
def run_generic_set_actions(env: Primaite): def run_generic_set_actions(env: Primaite):
"""Run against a generic agent with specified blue agent actions.""" """Run against a generic agent with specified blue agent actions."""
# Reset the environment at the start of the episode # Reset the environment at the start of the episode
# env.reset() env.reset()
training_config = env.training_config training_config = env.training_config
for episode in range(0, training_config.num_episodes): for episode in range(0, training_config.num_episodes):
for step in range(0, training_config.num_steps): for step in range(0, training_config.num_steps):
@@ -31,12 +31,14 @@ def run_generic_set_actions(env: Primaite):
# [1, 1, 2, 1, 1, 1, 2] ACL Action # [1, 1, 2, 1, 1, 1, 2] ACL Action
# Creates an ACL rule # Creates an ACL rule
# Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2 # Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2
action = 291 # Rule in current observation: [2, 2, 3, 2, 2, 2]
action = 43
elif step == 7: elif step == 7:
# [1, 1, 3, 1, 2, 2, 1] ACL Action # [1, 1, 3, 1, 2, 2, 1] ACL Action
# Creates an ACL rule # Creates an ACL rule
# Allows traffic from PC1 to SWITCH 1 on port UDP at position 1 # Allows traffic from PC1 to SWITCH 1 on port UDP at position 1
action = 425 # 3, 1, 1, 1, 1,
action = 96
# Run the simulation step on the live environment # Run the simulation step on the live environment
obs, reward, done, info = env.step(action) obs, reward, done, info = env.step(action)
# Update observations space and return # Update observations space and return
@@ -282,7 +284,7 @@ class TestAccessControlList:
env.update_environent_obs() env.update_environent_obs()
# we have two ACLs # we have two ACLs
assert env.env_obs.shape == (5, 2) assert env.env_obs.shape == (6 * 3)
def test_values(self, env: Primaite): def test_values(self, env: Primaite):
"""Test that traffic values are encoded correctly. """Test that traffic values are encoded correctly.
@@ -305,7 +307,7 @@ class TestAccessControlList:
print(obs) print(obs)
assert np.array_equal(obs, []) assert np.array_equal(obs, [])
def test_observation_space(self): def test_observation_space_with_implicit_rule(self):
"""Test observation space is what is expected when an agent adds ACLs during an episode.""" """Test observation space is what is expected when an agent adds ACLs during an episode."""
# Used to use env from test fixture but AtrributeError function object has no 'training_config' # Used to use env from test fixture but AtrributeError function object has no 'training_config'
env = _get_primaite_env_from_config( env = _get_primaite_env_from_config(
@@ -314,5 +316,17 @@ class TestAccessControlList:
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
) )
run_generic_set_actions(env) run_generic_set_actions(env)
obs = env.env_obs
"""
Observation space at the end of the episode.
At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0
(0 represents its initial position at top of ACL list)
On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0
On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1
THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION
"""
# print("observation space",env.observation_space) # assert current_obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
assert np.array_equal(
obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
)