901 - updated observations.py to change and add new mapping of ACL rules to represent no rule present in list
This commit is contained in:
@@ -39,7 +39,9 @@ class AccessControlList:
|
||||
"""
|
||||
if self.acl_implicit_rule is not None:
|
||||
acl_list = self._acl + [self.acl_implicit_rule]
|
||||
return acl_list + [-1] * (self.max_acl_rules - len(acl_list))
|
||||
else:
|
||||
acl_list = self._acl
|
||||
return acl_list + [None] * (self.max_acl_rules - len(acl_list))
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
@@ -86,15 +88,18 @@ class AccessControlList:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule in self.acl:
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (
|
||||
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
|
||||
) and (str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == "DENY":
|
||||
return True
|
||||
elif rule.get_permission() == "ALLOW":
|
||||
return False
|
||||
if isinstance(rule, ACLRule):
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (
|
||||
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
|
||||
) and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == "DENY":
|
||||
return True
|
||||
elif rule.get_permission() == "ALLOW":
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
@@ -115,7 +120,6 @@ class AccessControlList:
|
||||
"""
|
||||
position_index = int(_position)
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
print(len(self._acl))
|
||||
if len(self._acl) + 1 < self.max_acl_rules:
|
||||
if _position is not None:
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
@@ -136,6 +140,7 @@ class AccessControlList:
|
||||
f"The ACL list is FULL."
|
||||
f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}."
|
||||
)
|
||||
# print("length of this list", len(self._acl))
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
|
||||
@@ -131,7 +131,8 @@ class LinkStatus(Enum):
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Implicit firewall rule."""
|
||||
"""Any firewall rule type."""
|
||||
|
||||
DENY = 0
|
||||
ALLOW = 1
|
||||
NA = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: STABLE_BASELINES3_A2C
|
||||
agent_identifier: STABLE_BASELINES3_PPO
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ACL
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
num_episodes: 1000
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
# Time delay between steps (for generic agents)
|
||||
|
||||
@@ -10,7 +10,6 @@ from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
Protocol,
|
||||
RulePermissionType,
|
||||
SoftwareState,
|
||||
)
|
||||
@@ -330,13 +329,14 @@ class AccessControlList(AbstractObservationComponent):
|
||||
]
|
||||
"""
|
||||
|
||||
0,
|
||||
# Terms (for ACL observation space):
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
# [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
@@ -346,18 +346,17 @@ class AccessControlList(AbstractObservationComponent):
|
||||
# 1. Define the shape of your observation space component
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 1,
|
||||
len(env.nodes) + 1,
|
||||
len(env.services_list),
|
||||
len(env.ports_list),
|
||||
env.max_number_acl_rules,
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 1,
|
||||
len(env.ports_list) + 1,
|
||||
env.max_number_acl_rules + 1,
|
||||
]
|
||||
# shape = acl_shape
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
print("obs space:", self.space)
|
||||
# print("obs space:", self.space)
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
@@ -367,8 +366,8 @@ class AccessControlList(AbstractObservationComponent):
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
|
||||
for index in range(len(self.env.acl.acl)):
|
||||
# print("starting len", len(self.env.acl.acl))
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
permission = acl_rule.permission
|
||||
@@ -378,26 +377,25 @@ class AccessControlList(AbstractObservationComponent):
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
|
||||
source_ip_int = -1
|
||||
dest_ip_int = -1
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == "DENY":
|
||||
permission_int = 0
|
||||
else:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 0
|
||||
source_ip_int = 1
|
||||
else:
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
# print(node.ip_address, source_ip, node.ip_address == source_ip)
|
||||
if (
|
||||
isinstance(node, ServiceNode)
|
||||
or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == source_ip:
|
||||
source_ip_int = node.node_id
|
||||
source_ip_int = int(node.node_id) + 1
|
||||
break
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 0
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
@@ -405,46 +403,46 @@ class AccessControlList(AbstractObservationComponent):
|
||||
isinstance(node, ServiceNode)
|
||||
or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == dest_ip:
|
||||
dest_ip_int = node.node_id
|
||||
dest_ip_int = int(node.node_id) + 1
|
||||
if protocol == "ANY":
|
||||
protocol_int = 0
|
||||
protocol_int = 1
|
||||
else:
|
||||
try:
|
||||
protocol_int = Protocol[protocol].value
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = -1
|
||||
protocol_int = None
|
||||
if port == "ANY":
|
||||
port_int = 0
|
||||
port_int = 1
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port)
|
||||
port_int = self.env.ports_list.index(port) + 2
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
|
||||
# Either do the multiply on the obs space
|
||||
# Change the obs to
|
||||
if source_ip_int != -1 and dest_ip_int != -1:
|
||||
items_to_add = [
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
position = position * 6
|
||||
for item in items_to_add:
|
||||
obs.insert(position, int(item))
|
||||
position += 1
|
||||
else:
|
||||
items_to_add = [-1, -1, -1, -1, -1, index]
|
||||
position = index * 6
|
||||
for item in items_to_add:
|
||||
obs.insert(position, int(item))
|
||||
position += 1
|
||||
items_to_add = [
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
position = position * 6
|
||||
for item in items_to_add:
|
||||
# print("position", position, "\nitem", int(item))
|
||||
obs.insert(position, int(item))
|
||||
position += 1
|
||||
else:
|
||||
starting_position = index * 6
|
||||
for placeholder in range(6):
|
||||
obs.insert(starting_position, 0)
|
||||
starting_position += 1
|
||||
|
||||
self.current_observation = obs
|
||||
# print("current obs", obs, "\n" ,len(obs))
|
||||
self.current_observation[:] = obs
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
|
||||
@@ -1204,7 +1204,8 @@ class Primaite(Env):
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
|
||||
|
||||
# [1, 1, 2, 1, 1, 1, 2] CREATE RULE ALLOW NODE 2 TO NODE 1 ON SERVICE 1 PORT 1 AT INDEX 2
|
||||
# 1, 2, 1, 6, 0, 0, 1 ALLOW NODE 2 TO NODE 1 ON SERVICE 1 SERVICE ANY PORT ANY AT INDEX 1
|
||||
action_key = 1
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
for action_decision in range(3):
|
||||
|
||||
@@ -26,12 +26,12 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
|
||||
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (True or False)
|
||||
apply_implicit_rule: True
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 10
|
||||
max_number_acl_rules: 3
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
|
||||
@@ -18,7 +18,7 @@ from tests.conftest import _get_primaite_env_from_config
|
||||
def run_generic_set_actions(env: Primaite):
|
||||
"""Run against a generic agent with specified blue agent actions."""
|
||||
# Reset the environment at the start of the episode
|
||||
# env.reset()
|
||||
env.reset()
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_episodes):
|
||||
for step in range(0, training_config.num_steps):
|
||||
@@ -31,12 +31,14 @@ def run_generic_set_actions(env: Primaite):
|
||||
# [1, 1, 2, 1, 1, 1, 2] ACL Action
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2
|
||||
action = 291
|
||||
# Rule in current observation: [2, 2, 3, 2, 2, 2]
|
||||
action = 43
|
||||
elif step == 7:
|
||||
# [1, 1, 3, 1, 2, 2, 1] ACL Action
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from PC1 to SWITCH 1 on port UDP at position 1
|
||||
action = 425
|
||||
# 3, 1, 1, 1, 1,
|
||||
action = 96
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
# Update observations space and return
|
||||
@@ -282,7 +284,7 @@ class TestAccessControlList:
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have two ACLs
|
||||
assert env.env_obs.shape == (5, 2)
|
||||
assert env.env_obs.shape == (6 * 3)
|
||||
|
||||
def test_values(self, env: Primaite):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
@@ -305,7 +307,7 @@ class TestAccessControlList:
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [])
|
||||
|
||||
def test_observation_space(self):
|
||||
def test_observation_space_with_implicit_rule(self):
|
||||
"""Test observation space is what is expected when an agent adds ACLs during an episode."""
|
||||
# Used to use env from test fixture but AtrributeError function object has no 'training_config'
|
||||
env = _get_primaite_env_from_config(
|
||||
@@ -314,5 +316,17 @@ class TestAccessControlList:
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
||||
)
|
||||
run_generic_set_actions(env)
|
||||
obs = env.env_obs
|
||||
"""
|
||||
Observation space at the end of the episode.
|
||||
At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0
|
||||
(0 represents its initial position at top of ACL list)
|
||||
On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0
|
||||
On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1
|
||||
THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION
|
||||
"""
|
||||
|
||||
# print("observation space",env.observation_space)
|
||||
# assert current_obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
assert np.array_equal(
|
||||
obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user