Files
PrimAITE/src/primaite/agents/hardcoded_acl.py

516 lines
22 KiB
Python
Raw Normal View History

# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
2023-07-13 12:25:54 +01:00
from typing import Dict, List, Union
2023-07-05 16:19:43 +01:00
import numpy as np
2023-07-05 16:19:43 +01:00
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView
2023-07-05 16:19:43 +01:00
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
2023-07-05 16:19:43 +01:00
def _calculate_action(self, obs: np.ndarray) -> int:
    """Dispatch to the action calculator selected by the training config.

    :param obs: Current observation from the environment.
    :type obs: np.ndarray
    :return: The action produced by the selected view-specific calculator.
    :rtype: int
    """
    configured_view = self._training_config.hard_coded_agent_view
    if configured_view == HardCodedAgentView.BASIC:
        # Basic view: decide from the current observation only.
        return self._calculate_action_basic_view(obs)
    # Any other view: decide using the full-view calculator.
    return self._calculate_action_full_view(obs)
2023-07-05 16:19:43 +01:00
def get_blocked_green_iers(
    self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, IER]:
    """Return the subset of green IERs that the ACL reports as blocked.

    :param green_iers: Green IERs to check, keyed by IER id.
    :type green_iers: Dict[str, IER]
    :param acl: Firewall rules
    :type acl: AccessControlList
    :param nodes: Nodes in the network, keyed by node id.
    :type nodes: Dict[str, NodeUnion]
    :return: Same mapping as ``green_iers`` but containing only the entries for which
        ``acl.is_blocked`` is truthy (insertion order is kept).
    :rtype: Dict[str, IER]
    """

    def _blocked(candidate):
        # Resolve both endpoints to IP addresses, then ask the ACL for a verdict.
        src_ip = nodes[candidate.get_source_node_id()].ip_address
        dst_ip = nodes[candidate.get_dest_node_id()].ip_address
        query = (src_ip, dst_ip, candidate.get_protocol(), candidate.get_port())
        return acl.is_blocked(*query)

    return {ier_id: candidate for ier_id, candidate in green_iers.items() if _blocked(candidate)}
2023-07-13 12:25:54 +01:00
def get_matching_acl_rules_for_ier(
    self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
    """Look up the ACL rules that are relevant to a single IER.

    :param ier: Information Exchange Request to query against the ACL list
    :type ier: IER
    :param acl: Firewall rules
    :type acl: AccessControlList
    :param nodes: Nodes in the network, keyed by node id.
    :type nodes: Dict[str, NodeUnion]
    :return: Whatever ``acl.get_relevant_rules`` returns for the IER's source address,
        destination address, protocol and port.
    :rtype: Dict[int, ACLRule]
    """
    # Node ids on the IER are translated to IP addresses before querying the ACL.
    src_ip = nodes[ier.get_source_node_id()].ip_address
    dst_ip = nodes[ier.get_dest_node_id()].ip_address
    query = (src_ip, dst_ip, ier.get_protocol(), ier.get_port())
    return acl.get_relevant_rules(*query)
2023-07-05 16:19:43 +01:00
def get_blocking_acl_rules_for_ier(
    self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
    """
    Get the DENY rules among the ACL rules relevant to an IER.

    .. warning::
        Can return empty dict but IER can still be blocked by default
        (No ALLOW rule, therefore blocked).

    :param ier: Information Exchange Request to query against the ACL list
    :type ier: IER
    :param acl: Firewall rules
    :type acl: AccessControlList
    :param nodes: Nodes in the network
    :type nodes: Dict[str, NodeUnion]
    :return: The relevant rules whose permission is ``"DENY"``, under their original keys.
    :rtype: Dict[int, ACLRule]
    """
    relevant = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    return {key: rule for key, rule in relevant.items() if rule.get_permission() == "DENY"}
2023-07-05 16:19:43 +01:00
def get_allow_acl_rules_for_ier(
    self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
    """Get the ALLOW rules among the ACL rules relevant to an IER.

    :param ier: Information Exchange Request to query against the ACL list
    :type ier: IER
    :param acl: Firewall rules
    :type acl: AccessControlList
    :param nodes: Nodes in the network
    :type nodes: Dict[str, NodeUnion]
    :return: The relevant rules whose permission is ``"ALLOW"``, under their original keys.
    :rtype: Dict[int, ACLRule]
    """
    relevant = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    return {key: rule for key, rule in relevant.items() if rule.get_permission() == "ALLOW"}
def get_matching_acl_rules(
    self,
    source_node_id: Union[int, str],
    dest_node_id: Union[int, str],
    protocol: Union[int, str],
    port: str,
    acl: AccessControlList,
    nodes: Dict[str, Union[ServiceNode, ActiveNode]],
    services_list: List[str],
) -> Dict[int, ACLRule]:
    """Filter ACL rules to only those which are relevant to the specified nodes.

    :param source_node_id: Source node id, or ``"ANY"``. Non-"ANY" values are converted with
        ``str()`` before being looked up in ``nodes``.
    :type source_node_id: Union[int, str]
    :param dest_node_id: Destination node id, or ``"ANY"`` (same lookup as the source).
    :type dest_node_id: Union[int, str]
    :param protocol: Network protocol. Either a service name / ``"ANY"`` (passed to the ACL
        unchanged) or an integer index where 0 means ANY and ``k >= 1`` selects
        ``services_list[k - 1]``.
    :type protocol: Union[int, str]
    :param port: Network port
    :type port: str
    :param acl: Access Control list which will be filtered
    :type acl: AccessControlList
    :param nodes: The environment's node directory.
    :type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
    :param services_list: List of services registered for the environment.
    :type services_list: List[str]
    :return: The result of ``acl.get_relevant_rules`` for the resolved addresses, protocol and port.
    :rtype: Dict[int, ACLRule]
    """
    if source_node_id != "ANY":
        source_node_address = nodes[str(source_node_id)].ip_address
    else:
        source_node_address = source_node_id

    if dest_node_id != "ANY":
        dest_node_address = nodes[str(dest_node_id)].ip_address
    else:
        dest_node_address = dest_node_id

    if isinstance(protocol, str):
        # Already a name ("ANY" or e.g. "TCP"); previously a non-"ANY" name crashed on `protocol - 1`.
        protocol_name = protocol
    elif protocol == 0:
        # Index 0 encodes ANY; without this guard it would wrap around to services_list[-1].
        protocol_name = "ANY"
    else:
        # -1 because ANY occupies index 0 and is not stored in the list of services.
        protocol_name = services_list[protocol - 1]

    return acl.get_relevant_rules(source_node_address, dest_node_address, protocol_name, port)
def get_allow_acl_rules(
    self,
    source_node_id: int,
    dest_node_id: str,
    protocol: int,
    port: str,
    acl: AccessControlList,
    nodes: Dict[str, NodeUnion],
    services_list: List[str],
) -> Dict[int, ACLRule]:
    """List ALLOW rules relating to specified nodes.

    :param source_node_id: Source node id
    :type source_node_id: int
    :param dest_node_id: Destination node
    :type dest_node_id: str
    :param protocol: Network protocol
    :type protocol: int
    :param port: Port
    :type port: str
    :param acl: Firewall ruleset which is applied to the network
    :type acl: AccessControlList
    :param nodes: The simulation's node store
    :type nodes: Dict[str, NodeUnion]
    :param services_list: Services list
    :type services_list: List[str]
    :return: The rules returned by ``get_matching_acl_rules`` for these arguments whose
        permission is ``"ALLOW"``, under their original keys.
    :rtype: Dict[int, ACLRule]
    """
    lookup_args = (source_node_id, dest_node_id, protocol, port, acl, nodes, services_list)
    candidates = self.get_matching_acl_rules(*lookup_args)
    return {key: rule for key, rule in candidates.items() if rule.get_permission() == "ALLOW"}
def get_deny_acl_rules(
    self,
    source_node_id: int,
    dest_node_id: str,
    protocol: int,
    port: str,
    acl: AccessControlList,
    nodes: Dict[str, NodeUnion],
    services_list: List[str],
) -> Dict[int, ACLRule]:
    """List DENY rules relating to specified nodes.

    :param source_node_id: Source node id
    :type source_node_id: int
    :param dest_node_id: Destination node
    :type dest_node_id: str
    :param protocol: Network protocol
    :type protocol: int
    :param port: Port
    :type port: str
    :param acl: Firewall ruleset which is applied to the network
    :type acl: AccessControlList
    :param nodes: The simulation's node store
    :type nodes: Dict[str, NodeUnion]
    :param services_list: Services list
    :type services_list: List[str]
    :return: The rules returned by ``get_matching_acl_rules`` for these arguments whose
        permission is ``"DENY"``, under their original keys.
    :rtype: Dict[int, ACLRule]
    """
    lookup_args = (source_node_id, dest_node_id, protocol, port, acl, nodes, services_list)
    candidates = self.get_matching_acl_rules(*lookup_args)
    return {key: rule for key, rule in candidates.items() if rule.get_permission() == "DENY"}
2023-07-05 16:19:43 +01:00
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
    """
    Calculate a good acl-based action for the blue agent to take.

    Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:

    - Which ACL rules already exist, - otherwise:
        - The agent would permanently get stuck in a loop of performing the same action over and over.
          (best action is to block something, but it's already blocked but doesn't know this)
        - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
          if it doesn't know what rules exist)
    - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
      in the default config one of the green IERs is blocked by default, but it has no way of knowing this
      based on the observation space. Additionally, potentially in the future, once a node state
      has been fixed (no longer compromised), it needs a way to know it should reallow traffic.
      A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this.

    There doesn't seem like there's much that can be done if an Operating or OS State is compromised.

    If a service node becomes compromised there's a decision to make - do we block that service?
        Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
        Cons: Will block a green IER, decreasing the reward

    We decide to block the service.

    Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
    nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
    an overwhelmed state, so we don't do this.

    :param obs: current observation from the gym environment
    :type obs: np.ndarray
    :return: Optimal action to take in the environment (chosen from the discrete action space)
    :rtype: int
    """
    # obs = convert_to_old_obs(obs)
    r_obs = transform_change_obs_readable(obs)
    # The first three entries are ignored here; everything from the fourth entry on is
    # treated as one sequence of per-node service states per service.
    _, _, _, *s = r_obs

    if len(r_obs) == 4:  # only 1 service
        # NOTE(review): starred unpacking already yields a list, so this is just a shallow copy.
        s = [*s]

    # 1. Check if node is compromised. If so we want to block its outwards services
    # a. If it is compromised check if there's an allow rule we should delete.
    #    cons: might delete a multi-rule from any source node (ANY -> x)
    # b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate
    # c. OPTIONAL (no allow rule = blocked): Add a DENY rule
    found_action = False
    for service_num, service_states in enumerate(s):
        for x, service_state in enumerate(service_states):
            if service_state == "COMPROMISED":
                # Query values: this node as source, any destination, this service, any port.
                action_source_id = x + 1  # +1 as 0 is any
                action_destination_id = "ANY"
                action_protocol = service_num + 1  # +1 as 0 is any
                action_port = "ANY"

                allow_rules = self.get_allow_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                deny_rules = self.get_deny_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                if len(allow_rules) > 0:
                    # Check if there's an allow rule we should delete.
                    # Only the first matching ALLOW rule is used; the action fields are
                    # overwritten with that rule's own source/destination/protocol/port.
                    rule = list(allow_rules.values())[0]
                    action_decision = "DELETE"
                    action_permission = "ALLOW"
                    action_source_ip = rule.get_source_ip()
                    # NOTE(review): assumes get_node_of_ip returns something int() accepts for
                    # every rule address (including a possible "ANY") - TODO confirm.
                    action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
                    action_destination_ip = rule.get_dest_ip()
                    action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
                    action_protocol_name = rule.get_protocol()
                    action_protocol = (
                        self._env.services_list.index(action_protocol_name) + 1
                    )  # convert name e.g. 'TCP' to index
                    action_port_name = rule.get_port()
                    action_port = (
                        self._env.ports_list.index(action_port_name) + 1
                    )  # convert port name e.g. '80' to index
                    found_action = True
                    break
                elif len(deny_rules) > 0:
                    # TODO OPTIONAL
                    # If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
                    # to create another
                    # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
                    continue
                else:
                    # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
                    # NOTE(review): found_action is NOT set in this branch, so the CREATE/DENY values
                    # assigned here never reach the action built at the end of this method. The break
                    # also stops scanning the remaining nodes of this service. Confirm whether
                    # `found_action = True` is missing or whether this should be `continue`.
                    action_decision = "CREATE"
                    action_permission = "DENY"
                    break
        if found_action:
            break

    # 2. If NO Node is compromised, or the node has already been blocked, check the green IERs and
    # add an Allow rule if the green IER is being blocked.
    # a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
    #    If there's a DENY rule delete it if:
    #       - There isn't already a deny rule
    #       - It doesn't allow a compromised node to become operational.
    # b. Add an ALLOW rule if:
    #       - There isn't already an allow rule
    #       - It doesn't allow a compromised node to become operational
    if not found_action:
        # Which Green IERS are blocked
        blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
        for ier_key, ier in blocked_green_iers.items():
            # Which ALLOW rules are allowing this IER (none)
            allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)

            # If there are no blocking rules, it may be being blocked by default
            # If there is already an allow rule
            node_id_to_check = int(ier.get_source_node_id())
            service_name_to_check = ier.get_protocol()
            service_id_to_check = self._env.services_list.index(service_name_to_check)

            # Service state of the source node in the ier
            # (node ids are used as 1-based positions into the per-service state sequence)
            service_state = s[service_id_to_check][node_id_to_check - 1]

            if len(allowing_rules) == 0 and service_state != "COMPROMISED":
                action_decision = "CREATE"
                action_permission = "ALLOW"
                action_source_id = int(ier.get_source_node_id())
                action_destination_id = int(ier.get_dest_node_id())
                action_protocol_name = ier.get_protocol()
                action_protocol = (
                    self._env.services_list.index(action_protocol_name) + 1
                )  # convert name e.g. 'TCP' to index
                action_port_name = ier.get_port()
                action_port = (
                    self._env.ports_list.index(action_port_name) + 1
                )  # convert port name e.g. '80' to index
                found_action = True
                break

    if found_action:
        # Assemble the six-field action chosen above and map it through the env's action_dict.
        action = [
            action_decision,
            action_permission,
            action_source_id,
            action_destination_id,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    else:
        # If no good/useful action has been found, just perform a nothing action
        action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    return action
2023-07-05 16:19:43 +01:00
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
    """
    Calculate a good acl-based action for the blue agent to take.

    Uses ONLY information from the current observation with NO knowledge
    of previous actions taken and NO reward feedback.

    We rely on randomness to select the precise action, as we want to
    block all traffic originating from a compromised node, without being
    able to tell:
        1. Which ACL rules already exist
        2. Which actions the agent has already tried.

    There is a high probability that the correct rule will not be deleted
    before the state becomes overwhelmed.

    Currently, a deny rule does not overwrite an allow rule. The allow
    rules must be deleted.

    :param obs: current observation from the gym environment
    :type obs: np.ndarray
    :return: Optimal action to take in the environment (chosen from the discrete action space)
    :rtype: int
    """
    lookup = self._env.action_dict
    readable = transform_change_obs_readable(obs)
    _, o_states, _, *per_service = readable

    if len(readable) == 4:  # only 1 service
        per_service = list(per_service)

    # Entries equal to "NONE" are not counted as nodes (they are links).
    node_count = sum(1 for entry in o_states if entry != "NONE")

    for service_idx, states in enumerate(per_service):
        compromised = [idx for idx, state in enumerate(states) if state == "COMPROMISED"]
        if compromised:
            # Random draws happen in a fixed order: source node, destination, port.
            source = np.random.choice(compromised) + 1  # +1 as 0 would be any

            # Randomly select a destination ID to block
            destination = np.random.choice(list(range(1, node_count + 1)) + ["ANY"])
            if destination != "ANY":
                destination = int(destination)

            # Randomly select a port to block
            # Bad assumption that number of protocols equals number of ports
            # AND no rules exist with an ANY port
            port = np.random.choice(list(range(1, len(per_service) + 1)))

            chosen = [
                "DELETE",
                "ALLOW",
                source,
                destination,
                service_idx + 1,  # +1 as 0 is any
                port,
            ]
            # We can only perform 1 action on each step
            return get_new_action(transform_action_acl_enum(chosen), lookup)
        # No states are COMPROMISED for this service, try the next one

    # If no good/useful action has been found, just perform a nothing action
    idle = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
    return get_new_action(transform_action_acl_enum(idle), lookup)