901 - added max_acl_rules, implicit_acl_rule and apply_implicit rule to main_config, changed observations.py for ACLs to match the action space for ACLs, added position of acl rule to ACL action type
This commit is contained in:
@@ -58,7 +58,7 @@ class TrainingConfig:
|
||||
implicit_acl_rule: str = "DENY"
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rule: int = 0
|
||||
max_number_acl_rules: int = 0
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
@@ -190,6 +190,7 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
|
||||
:raises TypeError: When the TrainingConfig object cannot be created
|
||||
using the values from the config file read from ``file_path``.
|
||||
"""
|
||||
print("FILE PATH", file_path)
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
|
||||
@@ -9,6 +9,7 @@ from gym import spaces
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
Protocol,
|
||||
RulePermissionType,
|
||||
SoftwareState,
|
||||
)
|
||||
@@ -309,11 +310,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
:param acl_implicit_rule: Whether to have an implicit DENY or implicit ALLOW ACL rule at the end of the ACL list
|
||||
Default is 0 DENY, 1 ALLOW
|
||||
:type acl_implicit_rule: ImplicitFirewallRule Enumeration (ALLOW or DENY)
|
||||
:param max_acl_rules: Maximum number of ACLs allowed in the environment
|
||||
:type max_acl_rules: int
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
@@ -334,6 +330,15 @@ class AccessControlList(AbstractObservationComponent):
|
||||
]
|
||||
"""
|
||||
|
||||
# Terms (for ACL observation space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
# [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
@@ -377,31 +382,54 @@ class AccessControlList(AbstractObservationComponent):
|
||||
permission_int = 1
|
||||
|
||||
if source_ip == "ANY":
|
||||
source_ip = 0
|
||||
source_ip_int = 0
|
||||
else:
|
||||
source_ip_int = self.obtain_node_id_using_ip(source_ip)
|
||||
if dest_ip == "ANY":
|
||||
dest_ip = 0
|
||||
if port == "ANY":
|
||||
port = 0
|
||||
dest_ip_int = 0
|
||||
else:
|
||||
dest_ip_int = self.obtain_node_id_using_ip(dest_ip)
|
||||
if protocol == "ANY":
|
||||
protocol_int = 0
|
||||
else:
|
||||
while True:
|
||||
if protocol in self.service_dict:
|
||||
protocol_int = self.services_dict[protocol]
|
||||
break
|
||||
else:
|
||||
self.services_dict[protocol] = len(self.services_dict) + 1
|
||||
continue
|
||||
# [0 - DENY, 1 - ALLOW] Permission
|
||||
# [0 - ANY, x - IP Address/Protocol/Port]
|
||||
try:
|
||||
protocol_int = Protocol[protocol]
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
if port == "ANY":
|
||||
port_int = 0
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port)
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
|
||||
print(permission_int, source_ip, dest_ip, protocol_int, port)
|
||||
print(permission_int, source_ip, dest_ip, protocol_int, port_int, position)
|
||||
obs.extend(
|
||||
[permission_int, source_ip, dest_ip, protocol_int, port, position]
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def obtain_node_id_using_ip(self, ip_address):
|
||||
"""Uses IP address of Nodes to find the ID.
|
||||
|
||||
Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space
|
||||
"""
|
||||
for key, node in self.env.nodes:
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
if node.ip_address == ip_address:
|
||||
return key
|
||||
_LOGGER.info(f"Node ID was not found from IP Address {ip_address}")
|
||||
return -1
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
|
||||
@@ -120,9 +120,8 @@ class Primaite(Env):
|
||||
# Create the Access Control List
|
||||
self.acl = AccessControlList(
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
|
||||
@@ -423,14 +422,13 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
len(self.action_dict[_action]) == 7
|
||||
): # ACL actions in multidiscrete form have len 7
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
@@ -981,6 +979,7 @@ class Primaite(Env):
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
0,
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
@@ -1173,6 +1172,10 @@ class Primaite(Env):
|
||||
actions = {0: [0, 0, 0, 0, 0, 0]}
|
||||
|
||||
action_key = 1
|
||||
print(
|
||||
"what is this primaite_env.py 1177",
|
||||
self.training_config.max_number_acl_rules - 1,
|
||||
)
|
||||
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
|
||||
for action_decision in range(3):
|
||||
# 2 possible action permissions 0 = DENY, 1 = CREATE
|
||||
@@ -1182,7 +1185,9 @@ class Primaite(Env):
|
||||
for dest_ip in range(self.num_nodes + 1):
|
||||
for protocol in range(self.num_services + 1):
|
||||
for port in range(self.num_ports + 1):
|
||||
for position in range(self.max_acl_rules - 1):
|
||||
for position in range(
|
||||
self.training_config.max_number_acl_rules - 1
|
||||
):
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
@@ -1192,10 +1197,11 @@ class Primaite(Env):
|
||||
port,
|
||||
position,
|
||||
]
|
||||
# Check to see if its an action we want to include as possible i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
# Check to see if it is an action we want to include as possible
|
||||
# i.e. not a nothing action
|
||||
if is_valid_acl_action_extra(action):
|
||||
actions[action_key] = action
|
||||
action_key += 1
|
||||
|
||||
return actions
|
||||
|
||||
@@ -1219,4 +1225,5 @@ class Primaite(Env):
|
||||
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
print("combined dict", combined_action_dict.items())
|
||||
return combined_action_dict
|
||||
|
||||
Reference in New Issue
Block a user