Add docstrings and type hints.
This commit is contained in:
@@ -10,19 +10,19 @@ class AccessControlList:
|
||||
|
||||
def __init__(self):
|
||||
"""Init."""
|
||||
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
|
||||
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
Checks for IP address matches.
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
|
||||
Args:
|
||||
_rule: The rule being checked
|
||||
_source_ip_address: the source IP address to compare
|
||||
_dest_ip_address: the destination IP address to compare
|
||||
|
||||
Returns:
|
||||
True if match; False otherwise.
|
||||
:param _rule: The rule object to check
|
||||
:type _rule: ACLRule
|
||||
:param _source_ip_address: Source IP address to compare
|
||||
:type _source_ip_address: str
|
||||
:param _dest_ip_address: Destination IP address to compare
|
||||
:type _dest_ip_address: str
|
||||
:return: True if there is a match, otherwise False.
|
||||
:rtype: bool
|
||||
"""
|
||||
if (
|
||||
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Any, Dict
|
||||
import numpy as np
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
@@ -11,13 +13,15 @@ from primaite.agents.utils import (
|
||||
)
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardCodedAgentView
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
@@ -26,7 +30,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
# history and reward feedback
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]):
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[Any, Any]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
@@ -54,8 +60,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]):
|
||||
"""Get list of ACL rules which are relevant to an IER
|
||||
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
@@ -76,7 +82,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
@@ -102,9 +110,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
|
||||
"""
|
||||
Get all allowing ACL rules for an IER.
|
||||
def get_allow_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, Any]:
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
@@ -126,19 +135,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_matching_acl_rules(
|
||||
self,
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get matching ACL rules.
|
||||
source_node_id: str,
|
||||
dest_node_id: str,
|
||||
protocol: str,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
"""Filter ACL rules to only those which are relevant to the specified nodes.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param source_node_id: Source node
|
||||
:type source_node_id: str
|
||||
:param dest_node_id: Destination nodes
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: str
|
||||
:param port: Network port
|
||||
:type port: str
|
||||
:param acl: Access Control list which will be filtered
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The environment's node directory.
|
||||
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
|
||||
:param services_list: List of services registered for the environment.
|
||||
:type services_list: List[str]
|
||||
:return: Filtered version of 'acl'
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
if source_node_id != "ANY":
|
||||
source_node_address = nodes[str(source_node_id)].ip_address
|
||||
@@ -158,19 +180,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_allow_acl_rules(
|
||||
self,
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get the ALLOW ACL rules.
|
||||
source_node_id: int,
|
||||
dest_node_id: str,
|
||||
protocol: int,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
"""List ALLOW rules relating to specified nodes.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param source_node_id: Source node id
|
||||
:type source_node_id: int
|
||||
:param dest_node_id: Destination node
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: int
|
||||
:param port: Port
|
||||
:type port: str
|
||||
:param acl: Firewall ruleset which is applied to the network
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The simulation's node store
|
||||
:type nodes: Dict[str, NodeUnion]
|
||||
:param services_list: Services list
|
||||
:type services_list: List[str]
|
||||
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
|
||||
desination nodes
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
@@ -191,19 +227,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
def get_deny_acl_rules(
|
||||
self,
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get the DENY ACL rules.
|
||||
source_node_id: int,
|
||||
dest_node_id: str,
|
||||
protocol: int,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[str, ACLRule]:
|
||||
"""List DENY rules relating to specified nodes.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param source_node_id: Source node id
|
||||
:type source_node_id: int
|
||||
:param dest_node_id: Destination node
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: int
|
||||
:param port: Port
|
||||
:type port: str
|
||||
:param acl: Firewall ruleset which is applied to the network
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The simulation's node store
|
||||
:type nodes: Dict[str, NodeUnion]
|
||||
:param services_list: Services list
|
||||
:type services_list: List[str]
|
||||
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
|
||||
desination nodes
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
@@ -222,7 +272,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs):
|
||||
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
@@ -250,8 +300,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
|
||||
an overwhelmed state, so we don't do this.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
# obs = convert_to_old_obs(obs)
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
@@ -387,7 +439,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
action = get_new_action(action, self._env.action_dict)
|
||||
return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs):
|
||||
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
|
||||
"""Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Uses ONLY information from the current observation with NO knowledge
|
||||
@@ -405,8 +457,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
Currently, a deny rule does not overwrite an allow rule. The allow
|
||||
rules must be deleted.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
@@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic Node agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good node-based action for the blue agent to take.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.common.enums import (
|
||||
@@ -10,15 +12,17 @@ from primaite.common.enums import (
|
||||
)
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: The same action list, but with the encodings translated back into meaningful labels
|
||||
:rtype: List[Union[int,str]]
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
@@ -33,15 +37,18 @@ def transform_action_node_readable(action):
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: The same action list, but with the encodings translated back into meaningful labels
|
||||
:rtype: List[Union[int,str]]
|
||||
"""
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
@@ -58,7 +65,7 @@ def transform_action_acl_readable(action):
|
||||
return new_action
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
def is_valid_node_action(action: List[int]) -> bool:
|
||||
"""Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
@@ -67,8 +74,11 @@ def is_valid_node_action(action):
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
@@ -93,7 +103,7 @@ def is_valid_node_action(action):
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
def is_valid_acl_action(action: List[int]) -> bool:
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
@@ -103,8 +113,11 @@ def is_valid_acl_action(action):
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
@@ -126,12 +139,15 @@ def is_valid_acl_action(action):
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
def is_valid_acl_action_extra(action: List[int]) -> bool:
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
@@ -150,15 +166,16 @@ def is_valid_acl_action_extra(action):
|
||||
return True
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs):
|
||||
"""
|
||||
Transform list of transactions to readable list of each observation property.
|
||||
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
|
||||
"""Transform list of transactions to readable list of each observation property.
|
||||
|
||||
example:
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs: Raw observation from the environment.
|
||||
:type obs: np.ndarray
|
||||
:return: The same observation, but the encoded integer values are replaced with readable names.
|
||||
:rtype: List[List[Union[str, int]]]
|
||||
"""
|
||||
ids = [i for i in obs[:, 0]]
|
||||
operating_states = [HardwareState(i).name for i in obs[:, 1]]
|
||||
@@ -174,14 +191,16 @@ def transform_change_obs_readable(obs):
|
||||
return new_obs
|
||||
|
||||
|
||||
def transform_obs_readable(obs):
|
||||
"""
|
||||
Transform observation to readable format.
|
||||
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
|
||||
"""Transform observation to readable format.
|
||||
|
||||
example
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs: Raw observation from the environment.
|
||||
:type obs: np.ndarray
|
||||
:return: The same observation, but the encoded integer values are replaced with readable names.
|
||||
:rtype: List[List[Union[str, int]]]
|
||||
"""
|
||||
changed_obs = transform_change_obs_readable(obs)
|
||||
new_obs = list(zip(*changed_obs))
|
||||
@@ -191,21 +210,23 @@ def transform_obs_readable(obs):
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_new_obs(obs, num_nodes=10):
|
||||
"""
|
||||
Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs: observation in the 'old' (NodeLinkTable) format
|
||||
:type obs: np.ndarray
|
||||
:param num_nodes: number of nodes in the network, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:return: reformatted observation
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
|
||||
new_obs = obs[:num_nodes, 1:].flatten()
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Convert to old observation.
|
||||
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
|
||||
"""Convert to old observation.
|
||||
|
||||
Links filled with 0's as no information is included in new observation space.
|
||||
|
||||
@@ -217,8 +238,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
[ 3, 1, 1, 1],
|
||||
...
|
||||
[20, 0, 0, 0]])
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
|
||||
:param obs: observation in the 'new' (MultiDiscrete) format
|
||||
:type obs: np.ndarray
|
||||
:param num_nodes: number of nodes in the network, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:param num_links: number of links in the network, defaults to 10
|
||||
:type num_links: int, optional
|
||||
:param num_services: number of services on the network, defaults to 1
|
||||
:type num_services: int, optional
|
||||
:return: 2-d BOX observation space, in the same format as NodeLinkTable
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
# Convert back to more readable, original format
|
||||
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
|
||||
@@ -240,17 +270,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Return string describing change between two observations.
|
||||
def describe_obs_change(
|
||||
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
|
||||
) -> str:
|
||||
"""Build a string describing the difference between two observations.
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
|
||||
output = 'ID 1: SERVICE 2 set to GOOD'
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
:param obs1: First observation
|
||||
:type obs1: np.ndarray
|
||||
:param obs2: Second observation
|
||||
:type obs2: np.ndarray
|
||||
:param num_nodes: How many nodes are in the network laydown, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:param num_links: How many links are in the network laydown, defaults to 10
|
||||
:type num_links: int, optional
|
||||
:param num_services: How many services are configured for this scenario, defaults to 1
|
||||
:type num_services: int, optional
|
||||
:return: A multi-line string with a human-readable description of the difference.
|
||||
:rtype: str
|
||||
"""
|
||||
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
|
||||
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
|
||||
|
||||
Reference in New Issue
Block a user