901 - updated observations.py to change and add new mapping of ACL rules to represent no rule present in list

This commit is contained in:
SunilSamra
2023-07-05 09:08:03 +01:00
parent 41aed12f27
commit 3f440c0a28
7 changed files with 95 additions and 76 deletions

View File

@@ -39,7 +39,9 @@ class AccessControlList:
"""
if self.acl_implicit_rule is not None:
acl_list = self._acl + [self.acl_implicit_rule]
return acl_list + [-1] * (self.max_acl_rules - len(acl_list))
else:
acl_list = self._acl
return acl_list + [None] * (self.max_acl_rules - len(acl_list))
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
@@ -86,15 +88,18 @@ class AccessControlList:
Indicates block if all conditions are satisfied.
"""
for rule in self.acl:
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
) and (str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"):
# There's a matching rule. Get the permission
if rule.get_permission() == "DENY":
return True
elif rule.get_permission() == "ALLOW":
return False
if isinstance(rule, ACLRule):
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
) and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule.get_permission() == "DENY":
return True
elif rule.get_permission() == "ALLOW":
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
@@ -115,7 +120,6 @@ class AccessControlList:
"""
position_index = int(_position)
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
print(len(self._acl))
if len(self._acl) + 1 < self.max_acl_rules:
if _position is not None:
if self.max_acl_rules - 1 > position_index > -1:
@@ -136,6 +140,7 @@ class AccessControlList:
f"The ACL list is FULL."
f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}."
)
# print("length of this list", len(self._acl))
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""

View File

@@ -131,7 +131,8 @@ class LinkStatus(Enum):
class RulePermissionType(Enum):
"""Implicit firewall rule."""
"""Any firewall rule type."""
DENY = 0
ALLOW = 1
NA = 0
DENY = 1
ALLOW = 2

View File

@@ -5,14 +5,14 @@
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
agent_identifier: STABLE_BASELINES3_PPO
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
action_type: ACL
# Number of episodes to run per session
num_episodes: 10
num_episodes: 1000
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)

View File

@@ -10,7 +10,6 @@ from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import (
FileSystemState,
HardwareState,
Protocol,
RulePermissionType,
SoftwareState,
)
@@ -330,13 +329,14 @@ class AccessControlList(AbstractObservationComponent):
]
"""
0,
# Terms (for ACL observation space):
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
# [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
# [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
# [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
# [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
# [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
_DATA_TYPE: type = np.int64
@@ -346,18 +346,17 @@ class AccessControlList(AbstractObservationComponent):
# 1. Define the shape of your observation space component
acl_shape = [
len(RulePermissionType),
len(env.nodes) + 1,
len(env.nodes) + 1,
len(env.services_list),
len(env.ports_list),
env.max_number_acl_rules,
len(env.nodes) + 2,
len(env.nodes) + 2,
len(env.services_list) + 1,
len(env.ports_list) + 1,
env.max_number_acl_rules + 1,
]
# shape = acl_shape
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
print("obs space:", self.space)
# print("obs space:", self.space)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
@@ -367,8 +366,8 @@ class AccessControlList(AbstractObservationComponent):
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
for index in range(len(self.env.acl.acl)):
# print("starting len", len(self.env.acl.acl))
for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission
@@ -378,26 +377,25 @@ class AccessControlList(AbstractObservationComponent):
port = acl_rule.port
position = index
source_ip_int = -1
dest_ip_int = -1
source_ip_int = None
dest_ip_int = None
if permission == "DENY":
permission_int = 0
else:
permission_int = 1
else:
permission_int = 2
if source_ip == "ANY":
source_ip_int = 0
source_ip_int = 1
else:
nodes = list(self.env.nodes.values())
for node in nodes:
# print(node.ip_address, source_ip, node.ip_address == source_ip)
if (
isinstance(node, ServiceNode)
or isinstance(node, ActiveNode)
) and node.ip_address == source_ip:
source_ip_int = node.node_id
source_ip_int = int(node.node_id) + 1
break
if dest_ip == "ANY":
dest_ip_int = 0
dest_ip_int = 1
else:
nodes = list(self.env.nodes.values())
for node in nodes:
@@ -405,46 +403,46 @@ class AccessControlList(AbstractObservationComponent):
isinstance(node, ServiceNode)
or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip:
dest_ip_int = node.node_id
dest_ip_int = int(node.node_id) + 1
if protocol == "ANY":
protocol_int = 0
protocol_int = 1
else:
try:
protocol_int = Protocol[protocol].value
protocol_int = self.env.services_list.index(protocol) + 2
except AttributeError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = -1
protocol_int = None
if port == "ANY":
port_int = 0
port_int = 1
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port)
port_int = self.env.ports_list.index(port) + 2
else:
_LOGGER.info(f"Port {port} could not be found.")
# Either do the multiply on the obs space
# Change the obs to
if source_ip_int != -1 and dest_ip_int != -1:
items_to_add = [
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
position = position * 6
for item in items_to_add:
obs.insert(position, int(item))
position += 1
else:
items_to_add = [-1, -1, -1, -1, -1, index]
position = index * 6
for item in items_to_add:
obs.insert(position, int(item))
position += 1
items_to_add = [
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
position = position * 6
for item in items_to_add:
# print("position", position, "\nitem", int(item))
obs.insert(position, int(item))
position += 1
else:
starting_position = index * 6
for placeholder in range(6):
obs.insert(starting_position, 0)
starting_position += 1
self.current_observation = obs
# print("current obs", obs, "\n" ,len(obs))
self.current_observation[:] = obs
class ObservationsHandler:

View File

@@ -1204,7 +1204,8 @@ class Primaite(Env):
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
# [1, 1, 2, 1, 1, 1, 2] CREATE RULE ALLOW NODE 2 TO NODE 1 ON SERVICE 1 PORT 1 AT INDEX 2
# 1, 2, 1, 6, 0, 0, 1 ALLOW NODE 2 TO NODE 1 ON SERVICE 1 SERVICE ANY PORT ANY AT INDEX 1
action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3):

View File

@@ -26,12 +26,12 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
# Choice whether to have an ALLOW or DENY implicit rule or not (True or False)
apply_implicit_rule: True
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 10
max_number_acl_rules: 3
observation_space:
components:

View File

@@ -18,7 +18,7 @@ from tests.conftest import _get_primaite_env_from_config
def run_generic_set_actions(env: Primaite):
"""Run against a generic agent with specified blue agent actions."""
# Reset the environment at the start of the episode
# env.reset()
env.reset()
training_config = env.training_config
for episode in range(0, training_config.num_episodes):
for step in range(0, training_config.num_steps):
@@ -31,12 +31,14 @@ def run_generic_set_actions(env: Primaite):
# [1, 1, 2, 1, 1, 1, 2] ACL Action
# Creates an ACL rule
# Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2
action = 291
# Rule in current observation: [2, 2, 3, 2, 2, 2]
action = 43
elif step == 7:
# [1, 1, 3, 1, 2, 2, 1] ACL Action
# Creates an ACL rule
# Allows traffic from PC1 to SWITCH 1 on port UDP at position 1
action = 425
# 3, 1, 1, 1, 1,
action = 96
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Update observations space and return
@@ -282,7 +284,7 @@ class TestAccessControlList:
env.update_environent_obs()
# we have two ACLs
assert env.env_obs.shape == (5, 2)
assert env.env_obs.shape == (6 * 3)
def test_values(self, env: Primaite):
"""Test that traffic values are encoded correctly.
@@ -305,7 +307,7 @@ class TestAccessControlList:
print(obs)
assert np.array_equal(obs, [])
def test_observation_space(self):
def test_observation_space_with_implicit_rule(self):
"""Test observation space is what is expected when an agent adds ACLs during an episode."""
# Used to use env from test fixture but AtrributeError function object has no 'training_config'
env = _get_primaite_env_from_config(
@@ -314,5 +316,17 @@ class TestAccessControlList:
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
)
run_generic_set_actions(env)
obs = env.env_obs
"""
Observation space at the end of the episode.
At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0
(0 represents its initial position at top of ACL list)
On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0
On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1
THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION
"""
# print("observation space",env.observation_space)
# assert current_obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
assert np.array_equal(
obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
)