Update some param descriptions for hardcoded agent

This commit is contained in:
Marek Wolan
2023-07-05 09:54:50 +01:00
parent e199dc52c0
commit 3ced1a1913

View File

@@ -1,4 +1,6 @@
from typing import Any, Dict
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
@@ -7,7 +9,9 @@ from primaite.agents.utils import (
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC):
@@ -22,12 +26,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
# history and reward feedback
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes):
"""
Get blocked green IERs.
def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]):
"""Get blocked green IERs.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param green_iers: Green IERs to check for being
:type green_iers: Dict[str, IER]
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
:rtype: Dict[str, IER]
"""
blocked_green_iers = {}
@@ -45,12 +54,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get matching ACL rules for an IER.
def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]):
"""Get list of ACL rules which are relevant to an IER
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
@@ -58,11 +72,11 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
dest_node_address = nodes[dest_node_id].ip_address
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
# I can't find where this function 'get_relevant_rules' is defined...
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
"""
Get blocking ACL rules for an IER.
@@ -70,8 +84,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked).
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
@@ -82,12 +102,18 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_rules
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
"""
Get all allowing ACL rules for an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)