901 - updated observations.py to change and add new mapping of ACL rules to represent no rule present in list
This commit is contained in:
@@ -39,7 +39,9 @@ class AccessControlList:
|
|||||||
"""
|
"""
|
||||||
if self.acl_implicit_rule is not None:
|
if self.acl_implicit_rule is not None:
|
||||||
acl_list = self._acl + [self.acl_implicit_rule]
|
acl_list = self._acl + [self.acl_implicit_rule]
|
||||||
return acl_list + [-1] * (self.max_acl_rules - len(acl_list))
|
else:
|
||||||
|
acl_list = self._acl
|
||||||
|
return acl_list + [None] * (self.max_acl_rules - len(acl_list))
|
||||||
|
|
||||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||||
"""
|
"""
|
||||||
@@ -86,15 +88,18 @@ class AccessControlList:
|
|||||||
Indicates block if all conditions are satisfied.
|
Indicates block if all conditions are satisfied.
|
||||||
"""
|
"""
|
||||||
for rule in self.acl:
|
for rule in self.acl:
|
||||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
if isinstance(rule, ACLRule):
|
||||||
if (
|
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||||
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
|
if (
|
||||||
) and (str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"):
|
rule.get_protocol() == _protocol or rule.get_protocol() == "ANY"
|
||||||
# There's a matching rule. Get the permission
|
) and (
|
||||||
if rule.get_permission() == "DENY":
|
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||||
return True
|
):
|
||||||
elif rule.get_permission() == "ALLOW":
|
# There's a matching rule. Get the permission
|
||||||
return False
|
if rule.get_permission() == "DENY":
|
||||||
|
return True
|
||||||
|
elif rule.get_permission() == "ALLOW":
|
||||||
|
return False
|
||||||
|
|
||||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||||
return True
|
return True
|
||||||
@@ -115,7 +120,6 @@ class AccessControlList:
|
|||||||
"""
|
"""
|
||||||
position_index = int(_position)
|
position_index = int(_position)
|
||||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||||
print(len(self._acl))
|
|
||||||
if len(self._acl) + 1 < self.max_acl_rules:
|
if len(self._acl) + 1 < self.max_acl_rules:
|
||||||
if _position is not None:
|
if _position is not None:
|
||||||
if self.max_acl_rules - 1 > position_index > -1:
|
if self.max_acl_rules - 1 > position_index > -1:
|
||||||
@@ -136,6 +140,7 @@ class AccessControlList:
|
|||||||
f"The ACL list is FULL."
|
f"The ACL list is FULL."
|
||||||
f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}."
|
f"The list of ACLs has length {len(self.acl)} and it has a max capacity of {self.max_acl_rules}."
|
||||||
)
|
)
|
||||||
|
# print("length of this list", len(self._acl))
|
||||||
|
|
||||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -131,7 +131,8 @@ class LinkStatus(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class RulePermissionType(Enum):
|
class RulePermissionType(Enum):
|
||||||
"""Implicit firewall rule."""
|
"""Any firewall rule type."""
|
||||||
|
|
||||||
DENY = 0
|
NA = 0
|
||||||
ALLOW = 1
|
DENY = 1
|
||||||
|
ALLOW = 2
|
||||||
|
|||||||
@@ -5,14 +5,14 @@
|
|||||||
# "STABLE_BASELINES3_PPO"
|
# "STABLE_BASELINES3_PPO"
|
||||||
# "STABLE_BASELINES3_A2C"
|
# "STABLE_BASELINES3_A2C"
|
||||||
# "GENERIC"
|
# "GENERIC"
|
||||||
agent_identifier: STABLE_BASELINES3_A2C
|
agent_identifier: STABLE_BASELINES3_PPO
|
||||||
# Sets How the Action Space is defined:
|
# Sets How the Action Space is defined:
|
||||||
# "NODE"
|
# "NODE"
|
||||||
# "ACL"
|
# "ACL"
|
||||||
# "ANY" node and acl actions
|
# "ANY" node and acl actions
|
||||||
action_type: NODE
|
action_type: ACL
|
||||||
# Number of episodes to run per session
|
# Number of episodes to run per session
|
||||||
num_episodes: 10
|
num_episodes: 1000
|
||||||
# Number of time_steps per episode
|
# Number of time_steps per episode
|
||||||
num_steps: 256
|
num_steps: 256
|
||||||
# Time delay between steps (for generic agents)
|
# Time delay between steps (for generic agents)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from primaite.acl.acl_rule import ACLRule
|
|||||||
from primaite.common.enums import (
|
from primaite.common.enums import (
|
||||||
FileSystemState,
|
FileSystemState,
|
||||||
HardwareState,
|
HardwareState,
|
||||||
Protocol,
|
|
||||||
RulePermissionType,
|
RulePermissionType,
|
||||||
SoftwareState,
|
SoftwareState,
|
||||||
)
|
)
|
||||||
@@ -330,13 +329,14 @@ class AccessControlList(AbstractObservationComponent):
|
|||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
0,
|
||||||
# Terms (for ACL observation space):
|
# Terms (for ACL observation space):
|
||||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
# [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
# [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
|
||||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
# [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to IP addresses)
|
||||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
# [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
# [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
# [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||||
|
|
||||||
_DATA_TYPE: type = np.int64
|
_DATA_TYPE: type = np.int64
|
||||||
|
|
||||||
@@ -346,18 +346,17 @@ class AccessControlList(AbstractObservationComponent):
|
|||||||
# 1. Define the shape of your observation space component
|
# 1. Define the shape of your observation space component
|
||||||
acl_shape = [
|
acl_shape = [
|
||||||
len(RulePermissionType),
|
len(RulePermissionType),
|
||||||
len(env.nodes) + 1,
|
len(env.nodes) + 2,
|
||||||
len(env.nodes) + 1,
|
len(env.nodes) + 2,
|
||||||
len(env.services_list),
|
len(env.services_list) + 1,
|
||||||
len(env.ports_list),
|
len(env.ports_list) + 1,
|
||||||
env.max_number_acl_rules,
|
env.max_number_acl_rules + 1,
|
||||||
]
|
]
|
||||||
# shape = acl_shape
|
|
||||||
shape = acl_shape * self.env.max_number_acl_rules
|
shape = acl_shape * self.env.max_number_acl_rules
|
||||||
|
|
||||||
# 2. Create Observation space
|
# 2. Create Observation space
|
||||||
self.space = spaces.MultiDiscrete(shape)
|
self.space = spaces.MultiDiscrete(shape)
|
||||||
print("obs space:", self.space)
|
# print("obs space:", self.space)
|
||||||
# 3. Initialise observation with zeroes
|
# 3. Initialise observation with zeroes
|
||||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||||
|
|
||||||
@@ -367,8 +366,8 @@ class AccessControlList(AbstractObservationComponent):
|
|||||||
The structure of the observation space is described in :class:`.AccessControlList`
|
The structure of the observation space is described in :class:`.AccessControlList`
|
||||||
"""
|
"""
|
||||||
obs = []
|
obs = []
|
||||||
|
# print("starting len", len(self.env.acl.acl))
|
||||||
for index in range(len(self.env.acl.acl)):
|
for index in range(0, len(self.env.acl.acl)):
|
||||||
acl_rule = self.env.acl.acl[index]
|
acl_rule = self.env.acl.acl[index]
|
||||||
if isinstance(acl_rule, ACLRule):
|
if isinstance(acl_rule, ACLRule):
|
||||||
permission = acl_rule.permission
|
permission = acl_rule.permission
|
||||||
@@ -378,26 +377,25 @@ class AccessControlList(AbstractObservationComponent):
|
|||||||
port = acl_rule.port
|
port = acl_rule.port
|
||||||
position = index
|
position = index
|
||||||
|
|
||||||
source_ip_int = -1
|
source_ip_int = None
|
||||||
dest_ip_int = -1
|
dest_ip_int = None
|
||||||
if permission == "DENY":
|
if permission == "DENY":
|
||||||
permission_int = 0
|
|
||||||
else:
|
|
||||||
permission_int = 1
|
permission_int = 1
|
||||||
|
else:
|
||||||
|
permission_int = 2
|
||||||
if source_ip == "ANY":
|
if source_ip == "ANY":
|
||||||
source_ip_int = 0
|
source_ip_int = 1
|
||||||
else:
|
else:
|
||||||
nodes = list(self.env.nodes.values())
|
nodes = list(self.env.nodes.values())
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
# print(node.ip_address, source_ip, node.ip_address == source_ip)
|
|
||||||
if (
|
if (
|
||||||
isinstance(node, ServiceNode)
|
isinstance(node, ServiceNode)
|
||||||
or isinstance(node, ActiveNode)
|
or isinstance(node, ActiveNode)
|
||||||
) and node.ip_address == source_ip:
|
) and node.ip_address == source_ip:
|
||||||
source_ip_int = node.node_id
|
source_ip_int = int(node.node_id) + 1
|
||||||
break
|
break
|
||||||
if dest_ip == "ANY":
|
if dest_ip == "ANY":
|
||||||
dest_ip_int = 0
|
dest_ip_int = 1
|
||||||
else:
|
else:
|
||||||
nodes = list(self.env.nodes.values())
|
nodes = list(self.env.nodes.values())
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
@@ -405,46 +403,46 @@ class AccessControlList(AbstractObservationComponent):
|
|||||||
isinstance(node, ServiceNode)
|
isinstance(node, ServiceNode)
|
||||||
or isinstance(node, ActiveNode)
|
or isinstance(node, ActiveNode)
|
||||||
) and node.ip_address == dest_ip:
|
) and node.ip_address == dest_ip:
|
||||||
dest_ip_int = node.node_id
|
dest_ip_int = int(node.node_id) + 1
|
||||||
if protocol == "ANY":
|
if protocol == "ANY":
|
||||||
protocol_int = 0
|
protocol_int = 1
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
protocol_int = Protocol[protocol].value
|
protocol_int = self.env.services_list.index(protocol) + 2
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
_LOGGER.info(f"Service {protocol} could not be found")
|
_LOGGER.info(f"Service {protocol} could not be found")
|
||||||
protocol_int = -1
|
protocol_int = None
|
||||||
if port == "ANY":
|
if port == "ANY":
|
||||||
port_int = 0
|
port_int = 1
|
||||||
else:
|
else:
|
||||||
if port in self.env.ports_list:
|
if port in self.env.ports_list:
|
||||||
port_int = self.env.ports_list.index(port)
|
port_int = self.env.ports_list.index(port) + 2
|
||||||
else:
|
else:
|
||||||
_LOGGER.info(f"Port {port} could not be found.")
|
_LOGGER.info(f"Port {port} could not be found.")
|
||||||
|
|
||||||
# Either do the multiply on the obs space
|
# Either do the multiply on the obs space
|
||||||
# Change the obs to
|
# Change the obs to
|
||||||
if source_ip_int != -1 and dest_ip_int != -1:
|
items_to_add = [
|
||||||
items_to_add = [
|
permission_int,
|
||||||
permission_int,
|
source_ip_int,
|
||||||
source_ip_int,
|
dest_ip_int,
|
||||||
dest_ip_int,
|
protocol_int,
|
||||||
protocol_int,
|
port_int,
|
||||||
port_int,
|
position,
|
||||||
position,
|
]
|
||||||
]
|
position = position * 6
|
||||||
position = position * 6
|
for item in items_to_add:
|
||||||
for item in items_to_add:
|
# print("position", position, "\nitem", int(item))
|
||||||
obs.insert(position, int(item))
|
obs.insert(position, int(item))
|
||||||
position += 1
|
position += 1
|
||||||
else:
|
else:
|
||||||
items_to_add = [-1, -1, -1, -1, -1, index]
|
starting_position = index * 6
|
||||||
position = index * 6
|
for placeholder in range(6):
|
||||||
for item in items_to_add:
|
obs.insert(starting_position, 0)
|
||||||
obs.insert(position, int(item))
|
starting_position += 1
|
||||||
position += 1
|
|
||||||
|
|
||||||
self.current_observation = obs
|
# print("current obs", obs, "\n" ,len(obs))
|
||||||
|
self.current_observation[:] = obs
|
||||||
|
|
||||||
|
|
||||||
class ObservationsHandler:
|
class ObservationsHandler:
|
||||||
|
|||||||
@@ -1204,7 +1204,8 @@ class Primaite(Env):
|
|||||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||||
# reserve 0 action to be a nothing action
|
# reserve 0 action to be a nothing action
|
||||||
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
|
actions = {0: [0, 0, 0, 0, 0, 0, 0]}
|
||||||
|
# [1, 1, 2, 1, 1, 1, 2] CREATE RULE ALLOW NODE 2 TO NODE 1 ON SERVICE 1 PORT 1 AT INDEX 2
|
||||||
|
# 1, 2, 1, 6, 0, 0, 1 ALLOW NODE 2 TO NODE 1 ON SERVICE 1 SERVICE ANY PORT ANY AT INDEX 1
|
||||||
action_key = 1
|
action_key = 1
|
||||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||||
for action_decision in range(3):
|
for action_decision in range(3):
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
# Choice whether to have an ALLOW or DENY implicit rule or not (True or False)
|
||||||
apply_implicit_rule: True
|
apply_implicit_rule: True
|
||||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||||
implicit_acl_rule: DENY
|
implicit_acl_rule: DENY
|
||||||
# Total number of ACL rules allowed in the environment
|
# Total number of ACL rules allowed in the environment
|
||||||
max_number_acl_rules: 10
|
max_number_acl_rules: 3
|
||||||
|
|
||||||
observation_space:
|
observation_space:
|
||||||
components:
|
components:
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from tests.conftest import _get_primaite_env_from_config
|
|||||||
def run_generic_set_actions(env: Primaite):
|
def run_generic_set_actions(env: Primaite):
|
||||||
"""Run against a generic agent with specified blue agent actions."""
|
"""Run against a generic agent with specified blue agent actions."""
|
||||||
# Reset the environment at the start of the episode
|
# Reset the environment at the start of the episode
|
||||||
# env.reset()
|
env.reset()
|
||||||
training_config = env.training_config
|
training_config = env.training_config
|
||||||
for episode in range(0, training_config.num_episodes):
|
for episode in range(0, training_config.num_episodes):
|
||||||
for step in range(0, training_config.num_steps):
|
for step in range(0, training_config.num_steps):
|
||||||
@@ -31,12 +31,14 @@ def run_generic_set_actions(env: Primaite):
|
|||||||
# [1, 1, 2, 1, 1, 1, 2] ACL Action
|
# [1, 1, 2, 1, 1, 1, 2] ACL Action
|
||||||
# Creates an ACL rule
|
# Creates an ACL rule
|
||||||
# Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2
|
# Allows traffic from SERVER to PC1 on port TCP 80 and place ACL at position 2
|
||||||
action = 291
|
# Rule in current observation: [2, 2, 3, 2, 2, 2]
|
||||||
|
action = 43
|
||||||
elif step == 7:
|
elif step == 7:
|
||||||
# [1, 1, 3, 1, 2, 2, 1] ACL Action
|
# [1, 1, 3, 1, 2, 2, 1] ACL Action
|
||||||
# Creates an ACL rule
|
# Creates an ACL rule
|
||||||
# Allows traffic from PC1 to SWITCH 1 on port UDP at position 1
|
# Allows traffic from PC1 to SWITCH 1 on port UDP at position 1
|
||||||
action = 425
|
# 3, 1, 1, 1, 1,
|
||||||
|
action = 96
|
||||||
# Run the simulation step on the live environment
|
# Run the simulation step on the live environment
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, done, info = env.step(action)
|
||||||
# Update observations space and return
|
# Update observations space and return
|
||||||
@@ -282,7 +284,7 @@ class TestAccessControlList:
|
|||||||
env.update_environent_obs()
|
env.update_environent_obs()
|
||||||
|
|
||||||
# we have two ACLs
|
# we have two ACLs
|
||||||
assert env.env_obs.shape == (5, 2)
|
assert env.env_obs.shape == (6 * 3)
|
||||||
|
|
||||||
def test_values(self, env: Primaite):
|
def test_values(self, env: Primaite):
|
||||||
"""Test that traffic values are encoded correctly.
|
"""Test that traffic values are encoded correctly.
|
||||||
@@ -305,7 +307,7 @@ class TestAccessControlList:
|
|||||||
print(obs)
|
print(obs)
|
||||||
assert np.array_equal(obs, [])
|
assert np.array_equal(obs, [])
|
||||||
|
|
||||||
def test_observation_space(self):
|
def test_observation_space_with_implicit_rule(self):
|
||||||
"""Test observation space is what is expected when an agent adds ACLs during an episode."""
|
"""Test observation space is what is expected when an agent adds ACLs during an episode."""
|
||||||
# Used to use env from test fixture but AtrributeError function object has no 'training_config'
|
# Used to use env from test fixture but AtrributeError function object has no 'training_config'
|
||||||
env = _get_primaite_env_from_config(
|
env = _get_primaite_env_from_config(
|
||||||
@@ -314,5 +316,17 @@ class TestAccessControlList:
|
|||||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown_ACL.yaml",
|
||||||
)
|
)
|
||||||
run_generic_set_actions(env)
|
run_generic_set_actions(env)
|
||||||
|
obs = env.env_obs
|
||||||
|
"""
|
||||||
|
Observation space at the end of the episode.
|
||||||
|
At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0
|
||||||
|
(0 represents its initial position at top of ACL list)
|
||||||
|
On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0
|
||||||
|
On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1
|
||||||
|
THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION
|
||||||
|
"""
|
||||||
|
|
||||||
# print("observation space",env.observation_space)
|
# assert current_obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||||
|
assert np.array_equal(
|
||||||
|
obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user