Update some param descriptions for hardcoded agent
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
import numpy as np
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
@@ -7,7 +9,9 @@ from primaite.agents.utils import (
|
||||
transform_action_acl_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardCodedAgentView
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
@@ -22,12 +26,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
# history and reward feedback
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(self, green_iers, acl, nodes):
|
||||
"""
|
||||
Get blocked green IERs.
|
||||
def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]):
|
||||
"""Get blocked green IERs.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param green_iers: Green IERs to check for being
|
||||
:type green_iers: Dict[str, IER]
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
|
||||
:rtype: Dict[str, IER]
|
||||
"""
|
||||
blocked_green_iers = {}
|
||||
|
||||
@@ -45,12 +54,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get matching ACL rules for an IER.
|
||||
def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]):
|
||||
"""Get list of ACL rules which are relevant to an IER
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
source_node_id = ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
@@ -58,11 +72,11 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
|
||||
# I can't find where this function 'get_relevant_rules' is defined...
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
@@ -70,8 +84,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
(No ALLOW rule, therefore blocked).
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
@@ -82,12 +102,18 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
|
||||
"""
|
||||
Get all allowing ACL rules for an IER.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user