Update some param descriptions for hardcoded agent

This commit is contained in:
Marek Wolan
2023-07-05 09:54:50 +01:00
parent e199dc52c0
commit 3ced1a1913

View File

@@ -1,4 +1,6 @@
from typing import Any, Dict
import numpy as np import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import ( from primaite.agents.utils import (
@@ -7,7 +9,9 @@ from primaite.agents.utils import (
transform_action_acl_enum, transform_action_acl_enum,
transform_change_obs_readable, transform_change_obs_readable,
) )
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView from primaite.common.enums import HardCodedAgentView
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC): class HardCodedACLAgent(HardCodedAgentSessionABC):
@@ -22,12 +26,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
# history and reward feedback # history and reward feedback
return self._calculate_action_full_view(obs) return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes): def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]):
""" """Get blocked green IERs.
Get blocked green IERs.
TODO: Add params and return in docstring. :param green_iers: Green IERs to check for being
TODO: Typehint params and return. :type green_iers: Dict[str, IER]
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
:rtype: Dict[str, IER]
""" """
blocked_green_iers = {} blocked_green_iers = {}
@@ -45,12 +54,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier, acl, nodes): def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]):
""" """Get list of ACL rules which are relevant to an IER
Get matching ACL rules for an IER.
TODO: Add params and return in docstring. :param ier: Information Exchange Request to query against the ACL list
TODO: Typehint params and return. :type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
""" """
source_node_id = ier.get_source_node_id() source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address source_node_address = nodes[source_node_id].ip_address
@@ -58,11 +72,11 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
dest_node_address = nodes[dest_node_id].ip_address dest_node_address = nodes[dest_node_id].ip_address
protocol = ier.get_protocol() # e.g. 'TCP' protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port() port = ier.get_port()
# I can't find where this function 'get_relevant_rules' is defined...
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
""" """
Get blocking ACL rules for an IER. Get blocking ACL rules for an IER.
@@ -70,8 +84,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
Can return empty dict but IER can still be blocked by default Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked). (No ALLOW rule, therefore blocked).
TODO: Add params and return in docstring. :param ier: Information Exchange Request to query against the ACL list
TODO: Typehint params and return. :type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
""" """
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
@@ -82,12 +102,18 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_rules return blocked_rules
def get_allow_acl_rules_for_ier(self, ier, acl, nodes): def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
""" """
Get all allowing ACL rules for an IER. Get all allowing ACL rules for an IER.
TODO: Add params and return in docstring. :param ier: Information Exchange Request to query against the ACL list
TODO: Typehint params and return. :type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
""" """
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)