901 - added max_acl_rules, implicit_acl_rule and apply_implicit rule to main_config, changed observations.py for ACLs to match the action space for ACLs, added position of acl rule to ACL action type
This commit is contained in:
@@ -58,7 +58,7 @@ class TrainingConfig:
|
||||
implicit_acl_rule: str = "DENY"
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rule: int = 0
|
||||
max_number_acl_rules: int = 0
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
@@ -190,6 +190,7 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
|
||||
:raises TypeError: When the TrainingConfig object cannot be created
|
||||
using the values from the config file read from ``file_path``.
|
||||
"""
|
||||
print("FILE PATH", file_path)
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
|
||||
@@ -9,6 +9,7 @@ from gym import spaces
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
Protocol,
|
||||
RulePermissionType,
|
||||
SoftwareState,
|
||||
)
|
||||
@@ -309,11 +310,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
:param acl_implicit_rule: Whether to have an implicit DENY or implicit ALLOW ACL rule at the end of the ACL list
|
||||
Default is 0 DENY, 1 ALLOW
|
||||
:type acl_implicit_rule: ImplicitFirewallRule Enumeration (ALLOW or DENY)
|
||||
:param max_acl_rules: Maximum number of ACLs allowed in the environment
|
||||
:type max_acl_rules: int
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
@@ -334,6 +330,15 @@ class AccessControlList(AbstractObservationComponent):
|
||||
]
|
||||
"""
|
||||
|
||||
# Terms (for ACL observation space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
@@ -377,31 +382,54 @@ class AccessControlList(AbstractObservationComponent):
|
||||
permission_int = 1
|
||||
|
||||
if source_ip == "ANY":
|
||||
source_ip = 0
|
||||
source_ip_int = 0
|
||||
else:
|
||||
source_ip_int = self.obtain_node_id_using_ip(source_ip)
|
||||
if dest_ip == "ANY":
|
||||
dest_ip = 0
|
||||
if port == "ANY":
|
||||
port = 0
|
||||
dest_ip_int = 0
|
||||
else:
|
||||
dest_ip_int = self.obtain_node_id_using_ip(dest_ip)
|
||||
if protocol == "ANY":
|
||||
protocol_int = 0
|
||||
else:
|
||||
while True:
|
||||
if protocol in self.service_dict:
|
||||
protocol_int = self.services_dict[protocol]
|
||||
break
|
||||
else:
|
||||
self.services_dict[protocol] = len(self.services_dict) + 1
|
||||
continue
|
||||
# [0 - DENY, 1 - ALLOW] Permission
|
||||
# [0 - ANY, x - IP Address/Protocol/Port]
|
||||
try:
|
||||
protocol_int = Protocol[protocol]
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
if port == "ANY":
|
||||
port_int = 0
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port)
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
|
||||
print(permission_int, source_ip, dest_ip, protocol_int, port)
|
||||
print(permission_int, source_ip, dest_ip, protocol_int, port_int, position)
|
||||
obs.extend(
|
||||
[permission_int, source_ip, dest_ip, protocol_int, port, position]
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def obtain_node_id_using_ip(self, ip_address):
|
||||
"""Uses IP address of Nodes to find the ID.
|
||||
|
||||
Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space
|
||||
"""
|
||||
for key, node in self.env.nodes:
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
if node.ip_address == ip_address:
|
||||
return key
|
||||
_LOGGER.info(f"Node ID was not found from IP Address {ip_address}")
|
||||
return -1
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
|
||||
@@ -120,9 +120,8 @@ class Primaite(Env):
|
||||
# Create the Access Control List
|
||||
self.acl = AccessControlList(
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
|
||||
@@ -423,14 +422,13 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
len(self.action_dict[_action]) == 7
|
||||
): # ACL actions in multidiscrete form have len 7
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
@@ -981,6 +979,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
0,
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
@@ -1173,6 +1172,10 @@ class Primaite(Env):
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
print(
|
||||
"what is this primaite_env.py 1177",
|
||||
self.training_config.max_number_acl_rules - 1,
|
||||
)
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
for action_decision in range(3):
|
||||
# 2 possible action permissions 0 = DENY, 1 = CREATE
|
||||
@@ -1182,7 +1185,9 @@ class Primaite(Env):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
for position in range(self.max_acl_rules - 1):
|
||||
for position in range(
|
||||
self.training_config.max_number_acl_rules - 1
|
||||
):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
@@ -1192,10 +1197,11 @@ class Primaite(Env):
|
||||
port,
|
||||
position,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
# Check to see if it is an action we want to include as possible
|
||||
# i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
@@ -1219,4 +1225,5 @@ class Primaite(Env):
|
||||
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
print("combined dict", combined_action_dict.items())
|
||||
return combined_action_dict
|
||||
|
||||
@@ -31,8 +31,6 @@ observation_space_high_value: 1_000_000_000
|
||||
apply_implicit_rule: True
|
||||
# Implicit ACL firewall rule at end of lists to be default action or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 10
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -27,7 +27,8 @@ def env(request):
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite):
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -140,7 +142,8 @@ class TestNodeLinkTable:
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -217,4 +220,5 @@ class TestLinkTrafficLevels:
|
||||
# we send 999 bits of data via link1 and link2 on service 1.
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
@@ -19,15 +19,15 @@ def run_generic_set_actions(env: Primaite):
|
||||
action = 0
|
||||
print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1]
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from server_1 to node_1 on port FTP
|
||||
action = 7
|
||||
action = 56
|
||||
elif step == 7:
|
||||
# [1, 1, 2, 0] Node Action
|
||||
# Sets Node 1 Hardware State to OFF
|
||||
# Does not resolve any service
|
||||
action = 16
|
||||
action = 128
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
@@ -48,7 +48,8 @@ def test_single_action_space_is_valid():
|
||||
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
|
||||
run_generic_set_actions(env)
|
||||
@@ -77,8 +78,10 @@ def test_single_action_space_is_valid():
|
||||
def test_agent_is_executing_actions_from_both_spaces():
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
run_generic_set_actions(env)
|
||||
|
||||
Reference in New Issue
Block a user