#901 -
- Added comments in access_control_list.py - Changed obs_shape to max_number_acl_rules from max_number_acl_rules + 1 as index starts from 1 - Commented episode and step print line from test_single_action_space.py
This commit is contained in:
@@ -35,13 +35,14 @@ class AccessControlList:
|
||||
def acl(self):
|
||||
"""Public access method for private _acl.
|
||||
|
||||
Adds implicit rule to end of acl list and
|
||||
Pads out rest of list (if empty) with -1.
|
||||
Adds implicit rule to the BACK of the list after ALL the OTHER ACL rules and
|
||||
pads out rest of list (if it is empty) with None.
|
||||
"""
|
||||
if self.acl_implicit_rule is not None:
|
||||
acl_list = self._acl + [self.acl_implicit_rule]
|
||||
else:
|
||||
acl_list = self._acl
|
||||
|
||||
return acl_list + [None] * (self.max_acl_rules - len(acl_list))
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
@@ -113,13 +114,17 @@ class AccessControlList:
|
||||
return
|
||||
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# Checks position is in correct range
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
try:
|
||||
_LOGGER.info(f"Position {position_index} is valid.")
|
||||
# Check to see Agent will not overwrite current ACL in ACL list
|
||||
if self._acl[position_index] is None:
|
||||
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
|
||||
# Adds rule
|
||||
self._acl[position_index] = new_rule
|
||||
else:
|
||||
# Cannot overwrite it
|
||||
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
|
||||
return
|
||||
except Exception:
|
||||
@@ -140,7 +145,7 @@ class AccessControlList:
|
||||
"""
|
||||
# Add check so you cant remove implicit rule
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
# There will not always be something removable since the agent will be trying random things
|
||||
try:
|
||||
self.acl.remove(rule)
|
||||
except Exception:
|
||||
|
||||
@@ -408,9 +408,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
@@ -429,6 +426,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
Terms (for ACL Observation Space):
|
||||
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
|
||||
@@ -436,27 +434,37 @@ class AccessControlList(AbstractObservationComponent):
|
||||
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
|
||||
# Number of ACL rules incremented by 1 for positions starting at index 0.
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules + 1,
|
||||
env.max_number_acl_rules,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
# print("obs space:", self.space)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
@@ -468,7 +476,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
# print("starting len", len(self.env.acl.acl))
|
||||
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
@@ -478,7 +486,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
protocol = acl_rule.protocol
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == "DENY":
|
||||
@@ -488,6 +496,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to source IP address
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
@@ -498,6 +507,8 @@ class AccessControlList(AbstractObservationComponent):
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to dest IP address
|
||||
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
@@ -507,6 +518,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
if protocol == "ANY":
|
||||
protocol_int = 1
|
||||
else:
|
||||
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
|
||||
try:
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
@@ -520,7 +532,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
port_int = None
|
||||
|
||||
# Add to current obs
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
@@ -533,9 +545,9 @@ class AccessControlList(AbstractObservationComponent):
|
||||
)
|
||||
|
||||
else:
|
||||
# The Nothing or NA representation of 'NONE' ACL rules
|
||||
obs.extend([0, 0, 0, 0, 0, 0])
|
||||
|
||||
# print("current obs", obs, "\n" ,len(obs))
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
|
||||
@@ -19,7 +19,7 @@ def run_generic_set_actions(env: Primaite):
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
print("Episode:", episode, "\nStep:", step)
|
||||
# print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# Creates an ACL rule
|
||||
|
||||
Reference in New Issue
Block a user