- Added comments in access_control_list.py
- Changed obs_shape to max_number_acl_rules from max_number_acl_rules + 1 as index starts from 1
- Commented episode and step print line from test_single_action_space.py
This commit is contained in:
SunilSamra
2023-07-14 15:27:37 +01:00
parent eb75d15722
commit 661c865108
3 changed files with 30 additions and 13 deletions

View File

@@ -35,13 +35,14 @@ class AccessControlList:
def acl(self):
"""Public access method for private _acl.
Adds implicit rule to end of acl list and
Pads out rest of list (if empty) with -1.
Adds implicit rule to the BACK of the list after ALL the OTHER ACL rules and
pads out rest of list (if it is empty) with None.
"""
if self.acl_implicit_rule is not None:
acl_list = self._acl + [self.acl_implicit_rule]
else:
acl_list = self._acl
return acl_list + [None] * (self.max_acl_rules - len(acl_list))
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
@@ -113,13 +114,17 @@ class AccessControlList:
return
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
# Checks position is in correct range
if self.max_acl_rules - 1 > position_index > -1:
try:
_LOGGER.info(f"Position {position_index} is valid.")
# Check to see Agent will not overwrite current ACL in ACL list
if self._acl[position_index] is None:
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
# Adds rule
self._acl[position_index] = new_rule
else:
# Cannot overwrite it
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
return
except Exception:
@@ -140,7 +145,7 @@ class AccessControlList:
"""
# Add check so you cant remove implicit rule
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
# There will not always be something 'popable' since the agent will be trying random things
# There will not always be something removable since the agent will be trying random things
try:
self.acl.remove(rule)
except Exception:

View File

@@ -408,9 +408,6 @@ class AccessControlList(AbstractObservationComponent):
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
:param env: The environment that forms the basis of the observations
:type env: Primaite
Each ACL Rule has 6 elements. It will have the following structure:
.. code-block::
[
@@ -429,6 +426,7 @@ class AccessControlList(AbstractObservationComponent):
...
]
Terms (for ACL Observation Space):
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
@@ -436,27 +434,37 @@ class AccessControlList(AbstractObservationComponent):
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
"""
Initialise an AccessControlList observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
# Number of ACL rules incremented by 1 for positions starting at index 0.
acl_shape = [
len(RulePermissionType),
len(env.nodes) + 2,
len(env.nodes) + 2,
len(env.services_list) + 2,
len(env.ports_list) + 2,
env.max_number_acl_rules + 1,
env.max_number_acl_rules,
]
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# print("obs space:", self.space)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
@@ -468,7 +476,7 @@ class AccessControlList(AbstractObservationComponent):
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
# print("starting len", len(self.env.acl.acl))
for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
@@ -478,7 +486,7 @@ class AccessControlList(AbstractObservationComponent):
protocol = acl_rule.protocol
port = acl_rule.port
position = index
# Map each ACL attribute from what it was to an integer to fit the observation space
source_ip_int = None
dest_ip_int = None
if permission == "DENY":
@@ -488,6 +496,7 @@ class AccessControlList(AbstractObservationComponent):
if source_ip == "ANY":
source_ip_int = 1
else:
# Map Node ID (+ 1) to source IP address
nodes = list(self.env.nodes.values())
for node in nodes:
if (
@@ -498,6 +507,8 @@ class AccessControlList(AbstractObservationComponent):
if dest_ip == "ANY":
dest_ip_int = 1
else:
# Map Node ID (+ 1) to dest IP address
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
nodes = list(self.env.nodes.values())
for node in nodes:
if (
@@ -507,6 +518,7 @@ class AccessControlList(AbstractObservationComponent):
if protocol == "ANY":
protocol_int = 1
else:
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
try:
protocol_int = self.env.services_list.index(protocol) + 2
except AttributeError:
@@ -520,7 +532,7 @@ class AccessControlList(AbstractObservationComponent):
else:
_LOGGER.info(f"Port {port} could not be found.")
port_int = None
# Add to current obs
obs.extend(
[
permission_int,
@@ -533,9 +545,9 @@ class AccessControlList(AbstractObservationComponent):
)
else:
# The Nothing or NA representation of 'NONE' ACL rules
obs.extend([0, 0, 0, 0, 0, 0])
# print("current obs", obs, "\n" ,len(obs))
self.current_observation[:] = obs
def generate_structure(self):

View File

@@ -19,7 +19,7 @@ def run_generic_set_actions(env: Primaite):
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
print("Episode:", episode, "\nStep:", step)
# print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1, 1(position)]
# Creates an ACL rule