Format docstrings
This commit is contained in:
@@ -55,4 +55,4 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
html_theme = "furo"
|
||||
html_static_path = ["_static"]
|
||||
html_favicon = 'source/primaite.ico'
|
||||
html_favicon = "source/primaite.ico"
|
||||
|
||||
@@ -66,8 +66,7 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
|
||||
|
||||
# region Setup Logging
|
||||
class _LevelFormatter(Formatter):
|
||||
"""
|
||||
A custom level-specific formatter.
|
||||
"""A custom level-specific formatter.
|
||||
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
@@ -135,8 +134,7 @@ _LOGGER.addHandler(_FILE_HANDLER)
|
||||
|
||||
|
||||
def getLogger(name: str) -> Logger: # noqa
|
||||
"""
|
||||
Get a PrimAITE logger.
|
||||
"""Get a PrimAITE logger.
|
||||
|
||||
:param name: The logger name. Use ``__name__``.
|
||||
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
|
||||
|
||||
@@ -13,8 +13,7 @@ class AccessControlList:
|
||||
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
Checks for IP address matches.
|
||||
"""Checks for IP address matches.
|
||||
|
||||
Args:
|
||||
_rule: The rule being checked
|
||||
@@ -35,8 +34,7 @@ class AccessControlList:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
"""
|
||||
Checks for rules that block a protocol / port.
|
||||
"""Checks for rules that block a protocol / port.
|
||||
|
||||
Args:
|
||||
_source_ip_address: the source IP address to check
|
||||
@@ -62,8 +60,7 @@ class AccessControlList:
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Adds a new rule.
|
||||
"""Adds a new rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -77,8 +74,7 @@ class AccessControlList:
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Removes a rule.
|
||||
"""Removes a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -100,8 +96,7 @@ class AccessControlList:
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
"""Produces a hash value for a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
|
||||
@@ -6,8 +6,7 @@ class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
_permission: The permission (ALLOW or DENY)
|
||||
@@ -23,8 +22,7 @@ class ACLRule:
|
||||
self.port = _port
|
||||
|
||||
def __hash__(self):
|
||||
"""
|
||||
Override the hash function.
|
||||
"""Override the hash function.
|
||||
|
||||
Returns:
|
||||
Returns hash of core parameters.
|
||||
@@ -40,8 +38,7 @@ class ACLRule:
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
"""Gets the permission attribute.
|
||||
|
||||
Returns:
|
||||
Returns permission attribute
|
||||
@@ -49,8 +46,7 @@ class ACLRule:
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
"""Gets the source IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns source IP address attribute
|
||||
@@ -58,8 +54,7 @@ class ACLRule:
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
"""Gets the desintation IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns destination IP address attribute
|
||||
@@ -67,8 +62,7 @@ class ACLRule:
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
"""Gets the protocol attribute.
|
||||
|
||||
Returns:
|
||||
Returns protocol attribute
|
||||
@@ -76,8 +70,7 @@ class ACLRule:
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets the port attribute.
|
||||
"""Gets the port attribute.
|
||||
|
||||
Returns:
|
||||
Returns port attribute
|
||||
|
||||
@@ -21,8 +21,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
"""Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
@@ -39,11 +38,10 @@ def get_session_path(session_timestamp: datetime) -> Path:
|
||||
|
||||
|
||||
class AgentSessionABC(ABC):
|
||||
"""
|
||||
An ABC that manages training and/or evaluation of agents in PrimAITE.
|
||||
"""An ABC that manages training and/or evaluation of agents in PrimAITE.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from
|
||||
with all implemented abstract methods implemented.
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -186,8 +184,7 @@ class AgentSessionABC(ABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
"""Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -204,8 +201,7 @@ class AgentSessionABC(ABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
"""Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -293,11 +289,10 @@ class AgentSessionABC(ABC):
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
"""An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from
|
||||
with all implemented abstract methods implemented.
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
@@ -325,8 +320,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
"""Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -340,8 +334,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
"""Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
|
||||
@@ -23,8 +23,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(self, green_iers, acl, nodes):
|
||||
"""
|
||||
Get blocked green IERs.
|
||||
"""Get blocked green IERs.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -46,8 +45,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get matching ACL rules for an IER.
|
||||
"""Get matching ACL rules for an IER.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -63,8 +61,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
"""Get blocking ACL rules for an IER.
|
||||
|
||||
.. warning::
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
@@ -83,8 +80,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""
|
||||
Get all allowing ACL rules for an IER.
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -108,8 +104,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get matching ACL rules.
|
||||
"""Get matching ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -140,8 +135,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get the ALLOW ACL rules.
|
||||
"""Get the ALLOW ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -173,8 +167,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""
|
||||
Get the DENY ACL rules.
|
||||
"""Get the DENY ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -197,8 +190,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return allowed_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs):
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
"""Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic Node agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
"""
|
||||
Calculate a good node-based action for the blue agent to take.
|
||||
"""Calculate a good node-based action for the blue agent to take.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
|
||||
@@ -128,8 +128,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
"""Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -147,8 +146,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
"""Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
|
||||
@@ -77,8 +77,7 @@ class SB3Agent(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
"""Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -98,8 +97,7 @@ class SB3Agent(AgentSessionABC):
|
||||
deterministic: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
"""Evaluate the agent.
|
||||
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
|
||||
@@ -3,8 +3,7 @@ from primaite.agents.utils import get_new_action, transform_action_acl_enum, tra
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A Random Agent.
|
||||
"""A Random Agent.
|
||||
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
@@ -14,11 +13,9 @@ class RandomAgent(HardCodedAgentSessionABC):
|
||||
|
||||
|
||||
class DummyAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A Dummy Agent.
|
||||
"""A Dummy Agent.
|
||||
|
||||
All action spaces setup so dummy action is always 0 regardless of action
|
||||
type used.
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
@@ -26,8 +23,7 @@ class DummyAgent(HardCodedAgentSessionABC):
|
||||
|
||||
|
||||
class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A do nothing ACL agent.
|
||||
"""A do nothing ACL agent.
|
||||
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
@@ -41,8 +37,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
|
||||
class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A do nothing Node agent.
|
||||
"""A do nothing Node agent.
|
||||
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
@@ -11,8 +11,7 @@ from primaite.common.enums import (
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
@@ -34,8 +33,7 @@ def transform_action_node_readable(action):
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
"""Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
@@ -94,8 +92,7 @@ def is_valid_node_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
"""Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect.
|
||||
|
||||
@@ -127,8 +124,7 @@ def is_valid_acl_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
"""Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -151,8 +147,7 @@ def is_valid_acl_action_extra(action):
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs):
|
||||
"""
|
||||
Transform list of transactions to readable list of each observation property.
|
||||
"""Transform list of transactions to readable list of each observation property.
|
||||
|
||||
example:
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||
@@ -174,8 +169,7 @@ def transform_change_obs_readable(obs):
|
||||
|
||||
|
||||
def transform_obs_readable(obs):
|
||||
"""
|
||||
Transform observation to readable format.
|
||||
"""Transform observation to readable format.
|
||||
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
|
||||
@@ -191,8 +185,7 @@ def transform_obs_readable(obs):
|
||||
|
||||
|
||||
def convert_to_new_obs(obs, num_nodes=10):
|
||||
"""
|
||||
Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -203,8 +196,7 @@ def convert_to_new_obs(obs, num_nodes=10):
|
||||
|
||||
|
||||
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Convert to old observation.
|
||||
"""Convert to old observation.
|
||||
|
||||
Links filled with 0's as no information is included in new observation space.
|
||||
|
||||
@@ -240,8 +232,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
|
||||
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Return string describing change between two observations.
|
||||
"""Return string describing change between two observations.
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
@@ -269,8 +260,7 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change, is_link):
|
||||
"""
|
||||
Helper funcion to describe what has changed.
|
||||
"""Helper funcion to describe what has changed.
|
||||
|
||||
example:
|
||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||
@@ -305,8 +295,7 @@ def _describe_obs_change_helper(obs_change, is_link):
|
||||
|
||||
|
||||
def transform_action_node_enum(action):
|
||||
"""
|
||||
Convert a node action from readable string format, to enumerated format.
|
||||
"""Convert a node action from readable string format, to enumerated format.
|
||||
|
||||
example:
|
||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||
@@ -337,8 +326,7 @@ def transform_action_node_enum(action):
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
@@ -360,8 +348,7 @@ def transform_action_node_readable(action):
|
||||
|
||||
|
||||
def node_action_description(action):
|
||||
"""
|
||||
Generate string describing a node-based action.
|
||||
"""Generate string describing a node-based action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -388,8 +375,7 @@ def node_action_description(action):
|
||||
|
||||
|
||||
def transform_action_acl_enum(action):
|
||||
"""
|
||||
Convert acl action from readable str format, to enumerated format.
|
||||
"""Convert acl action from readable str format, to enumerated format.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -411,8 +397,7 @@ def transform_action_acl_enum(action):
|
||||
|
||||
|
||||
def acl_action_description(action):
|
||||
"""
|
||||
Generate string describing an acl-based action.
|
||||
"""Generate string describing an acl-based action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -432,8 +417,7 @@ def acl_action_description(action):
|
||||
|
||||
|
||||
def get_node_of_ip(ip, node_dict):
|
||||
"""
|
||||
Get the node ID of an IP address.
|
||||
"""Get the node ID of an IP address.
|
||||
|
||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
||||
|
||||
@@ -480,8 +464,7 @@ def is_valid_node_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
"""Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
@@ -513,8 +496,7 @@ def is_valid_acl_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
"""Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -537,8 +519,7 @@ def is_valid_acl_action_extra(action):
|
||||
|
||||
|
||||
def get_new_action(old_action, action_dict):
|
||||
"""
|
||||
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
|
||||
Old_action can be either node or acl action type
|
||||
|
||||
|
||||
@@ -28,8 +28,7 @@ def build_dirs():
|
||||
|
||||
@app.command()
|
||||
def reset_notebooks(overwrite: bool = True):
|
||||
"""
|
||||
Force a reset of the demo notebooks in the users notebooks directory.
|
||||
"""Force a reset of the demo notebooks in the users notebooks directory.
|
||||
|
||||
:param overwrite: If True, will overwrite existing demo notebooks.
|
||||
"""
|
||||
@@ -40,8 +39,7 @@ def reset_notebooks(overwrite: bool = True):
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
"""Print the PrimAITE log file.
|
||||
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
@@ -61,8 +59,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
|
||||
|
||||
@app.command()
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
"""
|
||||
View or set the PrimAITE Log Level.
|
||||
"""View or set the PrimAITE Log Level.
|
||||
|
||||
To View, simply call: primaite log-level
|
||||
|
||||
@@ -113,8 +110,7 @@ def clean_up():
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True):
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
"""Perform the PrimAITE first-time setup.
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
@@ -152,8 +148,7 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
"""Run a PrimAITE session.
|
||||
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
@@ -178,8 +173,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
"""View or set the plotly template for Session plots.
|
||||
|
||||
To View, simply call: primaite plotly-template
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
_name: The protocol name
|
||||
@@ -16,8 +15,7 @@ class Protocol(object):
|
||||
self.load = 0 # bps
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the protocol name.
|
||||
"""Gets the protocol name.
|
||||
|
||||
Returns:
|
||||
The protocol name
|
||||
@@ -25,8 +23,7 @@ class Protocol(object):
|
||||
return self.name
|
||||
|
||||
def get_load(self):
|
||||
"""
|
||||
Gets the protocol load.
|
||||
"""Gets the protocol load.
|
||||
|
||||
Returns:
|
||||
The protocol load (bps)
|
||||
@@ -34,8 +31,7 @@ class Protocol(object):
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load):
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
"""Adds load to the protocol.
|
||||
|
||||
Args:
|
||||
_load: The load to add
|
||||
|
||||
@@ -8,8 +8,7 @@ class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
:param name: The service name.
|
||||
:param port: The service port.
|
||||
|
||||
@@ -12,8 +12,7 @@ _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy lay down config dict to the new format.
|
||||
"""Convert a legacy lay down config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy lay down config dict.
|
||||
"""
|
||||
@@ -22,12 +21,10 @@ def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> D
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
"""
|
||||
Read in a lay down config yaml file.
|
||||
"""Read in a lay down config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise
|
||||
False.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise False.
|
||||
:return: The lay down config as a dict.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
"""
|
||||
@@ -53,8 +50,7 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
"""The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -68,8 +64,7 @@ def ddos_basic_one_config_path() -> Path:
|
||||
|
||||
|
||||
def ddos_basic_two_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
"""The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -83,8 +78,7 @@ def ddos_basic_two_config_path() -> Path:
|
||||
|
||||
|
||||
def dos_very_basic_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
"""The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -98,8 +92,7 @@ def dos_very_basic_config_path() -> Path:
|
||||
|
||||
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
"""The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
|
||||
@@ -24,8 +24,7 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
"""
|
||||
The path to the example training_config_main.yaml file.
|
||||
"""The path to the example training_config_main.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -180,8 +179,7 @@ class TrainingConfig:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
"""Create an instance of TrainingConfig from a dict.
|
||||
|
||||
:param config_dict: The training config dict.
|
||||
:return: The instance of TrainingConfig.
|
||||
@@ -236,8 +234,7 @@ class TrainingConfig:
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
"""Read in a training config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise
|
||||
@@ -281,18 +278,14 @@ def convert_legacy_training_config_dict(
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_steps: int = 256,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
"""Convert a legacy training config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param agent_framework: The agent framework to use as legacy training
|
||||
configs don't have agent_framework values.
|
||||
:param agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have agent_identifier values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:param num_steps: The number of steps to set as legacy training configs
|
||||
don't have num_steps values.
|
||||
:param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values.
|
||||
:param agent_identifier: The red agent identifier to use as legacy training configs don't have agent_identifier
|
||||
values.
|
||||
:param action_type: The action space type to set as legacy training configs don't have action_type values.
|
||||
:param num_steps: The number of steps to set as legacy training configs don't have num_steps values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
@@ -312,8 +305,7 @@ def convert_legacy_training_config_dict(
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
"""Maps legacy training config keys to the new format keys.
|
||||
|
||||
:param legacy_key: A legacy training config key.
|
||||
:return: The mapped key.
|
||||
|
||||
@@ -22,8 +22,7 @@ def plot_av_reward_per_episode(
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""
|
||||
Plot the average reward per episode from a csv session output.
|
||||
"""Plot the average reward per episode from a csv session output.
|
||||
|
||||
:param av_reward_per_episode_csv: The average reward per episode csv
|
||||
file path.
|
||||
|
||||
@@ -376,8 +376,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
|
||||
This allows users to configure observation spaces by mixing and matching components.
|
||||
Each component can also define further parameters to make them more flexible.
|
||||
This allows users to configure observation spaces by mixing and matching components. Each component can also define
|
||||
further parameters to make them more flexible.
|
||||
"""
|
||||
|
||||
_REGISTRY: Final[Dict[str, type]] = {
|
||||
|
||||
@@ -67,14 +67,12 @@ class Primaite(Env):
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
"""
|
||||
The Primaite constructor.
|
||||
"""The Primaite constructor.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
:param timestamp_str: The session timestamp in the format: <yyyy-mm-dd>_<hh-mm- ss>.
|
||||
"""
|
||||
self.session_path: Final[Path] = session_path
|
||||
self.timestamp_str: Final[str] = timestamp_str
|
||||
@@ -256,8 +254,7 @@ class Primaite(Env):
|
||||
self.total_step_count = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
AI Gym Reset function.
|
||||
"""AI Gym Reset function.
|
||||
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
@@ -293,8 +290,7 @@ class Primaite(Env):
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
AI Gym Step function.
|
||||
"""AI Gym Step function.
|
||||
|
||||
Args:
|
||||
action: Action space from agent
|
||||
@@ -432,8 +428,7 @@ class Primaite(Env):
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
"""Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
@@ -452,8 +447,7 @@ class Primaite(Env):
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""
|
||||
Applies agent actions to the nodes.
|
||||
"""Applies agent actions to the nodes.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
@@ -540,8 +534,7 @@ class Primaite(Env):
|
||||
return
|
||||
|
||||
def apply_actions_to_acl(self, _action):
|
||||
"""
|
||||
Applies agent actions to the Access Control List [TO DO].
|
||||
"""Applies agent actions to the Access Control List [TO DO].
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
@@ -618,8 +611,7 @@ class Primaite(Env):
|
||||
return
|
||||
|
||||
def apply_time_based_updates(self):
|
||||
"""
|
||||
Updates anything that needs to count down and then change state.
|
||||
"""Updates anything that needs to count down and then change state.
|
||||
|
||||
e.g. reset / patching status
|
||||
"""
|
||||
@@ -716,8 +708,7 @@ class Primaite(Env):
|
||||
print("Environment configuration loaded")
|
||||
|
||||
def create_node(self, item):
|
||||
"""
|
||||
Creates a node from config data.
|
||||
"""Creates a node from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -797,8 +788,7 @@ class Primaite(Env):
|
||||
self.network_reference.add_nodes_from([node_ref])
|
||||
|
||||
def create_link(self, item: Dict):
|
||||
"""
|
||||
Creates a link from config data.
|
||||
"""Creates a link from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -841,8 +831,7 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_green_ier(self, item):
|
||||
"""
|
||||
Creates a green IER from config data.
|
||||
"""Creates a green IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -882,8 +871,7 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
Creates a red IER from config data.
|
||||
"""Creates a red IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -912,8 +900,7 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_green_pol(self, item):
|
||||
"""
|
||||
Creates a green PoL object from config data.
|
||||
"""Creates a green PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -946,8 +933,7 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_red_pol(self, item):
|
||||
"""
|
||||
Creates a red PoL object from config data.
|
||||
"""Creates a red PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -987,8 +973,7 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_acl_rule(self, item):
|
||||
"""
|
||||
Creates an ACL rule from config data.
|
||||
"""Creates an ACL rule from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -1008,8 +993,7 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
"""
|
||||
Creates a list of services (enum) from config data.
|
||||
"""Creates a list of services (enum) from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the services
|
||||
@@ -1024,8 +1008,7 @@ class Primaite(Env):
|
||||
self.num_services = len(self.services_list)
|
||||
|
||||
def create_ports_list(self, ports):
|
||||
"""
|
||||
Creates a list of ports from config data.
|
||||
"""Creates a list of ports from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the ports
|
||||
@@ -1048,8 +1031,7 @@ class Primaite(Env):
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info.
|
||||
"""Extracts action_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing action info
|
||||
@@ -1069,11 +1051,9 @@ class Primaite(Env):
|
||||
self.obs_config = obs_config
|
||||
|
||||
def reset_environment(self):
|
||||
"""
|
||||
# Resets environment.
|
||||
"""# Resets environment.
|
||||
|
||||
Uses config data config data in order to build the environment
|
||||
configuration.
|
||||
Uses config data config data in order to build the environment configuration.
|
||||
"""
|
||||
for item in self.lay_down_config:
|
||||
if item["item_type"] == "NODE":
|
||||
@@ -1095,8 +1075,7 @@ class Primaite(Env):
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
def reset_node(self, item):
|
||||
"""
|
||||
Resets the statuses of a node.
|
||||
"""Resets the statuses of a node.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -1143,8 +1122,7 @@ class Primaite(Env):
|
||||
pass
|
||||
|
||||
def create_node_action_dict(self):
|
||||
"""
|
||||
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
|
||||
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
|
||||
|
||||
@@ -1157,7 +1135,6 @@ class Primaite(Env):
|
||||
5: [1, 3, 1, 0],
|
||||
...
|
||||
}
|
||||
|
||||
"""
|
||||
# reserve 0 action to be a nothing action
|
||||
actions = {0: [1, 0, 0, 0]}
|
||||
@@ -1209,11 +1186,9 @@ class Primaite(Env):
|
||||
return actions
|
||||
|
||||
def create_node_and_acl_action_dict(self):
|
||||
"""
|
||||
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
"""Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
|
||||
The dictionary contains actions of both Node and ACL action types.
|
||||
|
||||
"""
|
||||
node_action_dict = self.create_node_action_dict()
|
||||
acl_action_dict = self.create_acl_action_dict()
|
||||
|
||||
@@ -21,8 +21,7 @@ def calculate_reward_function(
|
||||
step_count,
|
||||
config_values,
|
||||
):
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
"""Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
Args:
|
||||
initial_nodes: The nodes before red and blue agents take effect
|
||||
@@ -95,8 +94,7 @@ def calculate_reward_function(
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
"""Calculates score relating to the hardware state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -144,8 +142,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
"""Calculates score relating to the Software State of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -195,8 +192,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
"""Calculates score relating to the service state(s) of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -267,8 +263,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
"""Calculates score relating to the file system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
|
||||
@@ -9,8 +9,7 @@ class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
_id: The IER id
|
||||
@@ -30,8 +29,7 @@ class Link(object):
|
||||
self.add_protocol(protocol_name)
|
||||
|
||||
def add_protocol(self, _protocol):
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
"""Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to be added (enum)
|
||||
@@ -39,8 +37,7 @@ class Link(object):
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets link ID.
|
||||
"""Gets link ID.
|
||||
|
||||
Returns:
|
||||
Link ID
|
||||
@@ -48,8 +45,7 @@ class Link(object):
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self):
|
||||
"""
|
||||
Gets source node name.
|
||||
"""Gets source node name.
|
||||
|
||||
Returns:
|
||||
Source node name
|
||||
@@ -57,8 +53,7 @@ class Link(object):
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self):
|
||||
"""
|
||||
Gets destination node name.
|
||||
"""Gets destination node name.
|
||||
|
||||
Returns:
|
||||
Destination node name
|
||||
@@ -66,8 +61,7 @@ class Link(object):
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self):
|
||||
"""
|
||||
Gets bandwidth of link.
|
||||
"""Gets bandwidth of link.
|
||||
|
||||
Returns:
|
||||
Link bandwidth (bps)
|
||||
@@ -75,8 +69,7 @@ class Link(object):
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self):
|
||||
"""
|
||||
Gets list of protocols on this link.
|
||||
"""Gets list of protocols on this link.
|
||||
|
||||
Returns:
|
||||
List of protocols on this link
|
||||
@@ -84,8 +77,7 @@ class Link(object):
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self):
|
||||
"""
|
||||
Gets current total load on this link.
|
||||
"""Gets current total load on this link.
|
||||
|
||||
Returns:
|
||||
Total load on this link (bps)
|
||||
@@ -96,8 +88,7 @@ class Link(object):
|
||||
return total_load
|
||||
|
||||
def add_protocol_load(self, _protocol, _load):
|
||||
"""
|
||||
Adds a loading to a protocol on this link.
|
||||
"""Adds a loading to a protocol on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to load
|
||||
|
||||
@@ -25,8 +25,7 @@ class ActiveNode(Node):
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
@@ -52,8 +51,7 @@ class ActiveNode(Node):
|
||||
|
||||
@property
|
||||
def software_state(self) -> SoftwareState:
|
||||
"""
|
||||
Get the software_state.
|
||||
"""Get the software_state.
|
||||
|
||||
:return: The software_state.
|
||||
"""
|
||||
@@ -61,8 +59,7 @@ class ActiveNode(Node):
|
||||
|
||||
@software_state.setter
|
||||
def software_state(self, software_state: SoftwareState):
|
||||
"""
|
||||
Get the software_state.
|
||||
"""Get the software_state.
|
||||
|
||||
:param software_state: Software State.
|
||||
"""
|
||||
@@ -80,8 +77,7 @@ class ActiveNode(Node):
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
"""Sets Software State if the node is not compromised.
|
||||
|
||||
Args:
|
||||
software_state: Software State
|
||||
@@ -107,8 +103,7 @@ class ActiveNode(Node):
|
||||
self._software_state = SoftwareState.GOOD
|
||||
|
||||
def set_file_system_state(self, file_system_state: FileSystemState):
|
||||
"""
|
||||
Sets the file system state (actual and observed).
|
||||
"""Sets the file system state (actual and observed).
|
||||
|
||||
Args:
|
||||
file_system_state: File system state
|
||||
@@ -134,8 +129,7 @@ class ActiveNode(Node):
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
"""Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
Use for green PoL to prevent it overturning a compromised state
|
||||
|
||||
|
||||
@@ -18,8 +18,7 @@ class Node:
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
:param node_id: The node id.
|
||||
:param name: The name of the node.
|
||||
|
||||
@@ -15,8 +15,7 @@ class NodeStateInstructionGreen(object):
|
||||
_service_name,
|
||||
_state,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
_id: The node state instruction id
|
||||
@@ -36,8 +35,7 @@ class NodeStateInstructionGreen(object):
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets the start step.
|
||||
"""Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
@@ -45,8 +43,7 @@ class NodeStateInstructionGreen(object):
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets the end step.
|
||||
"""Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
@@ -54,8 +51,7 @@ class NodeStateInstructionGreen(object):
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
"""
|
||||
Gets the node ID.
|
||||
"""Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
@@ -63,8 +59,7 @@ class NodeStateInstructionGreen(object):
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
"""Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
@@ -72,8 +67,7 @@ class NodeStateInstructionGreen(object):
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""
|
||||
Gets the service name.
|
||||
"""Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
@@ -81,8 +75,7 @@ class NodeStateInstructionGreen(object):
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
"""Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
|
||||
@@ -23,8 +23,7 @@ class NodeStateInstructionRed(object):
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
_id: The node state instruction id
|
||||
@@ -52,8 +51,7 @@ class NodeStateInstructionRed(object):
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets the start step.
|
||||
"""Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
@@ -61,8 +59,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets the end step.
|
||||
"""Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
@@ -70,8 +67,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
"""
|
||||
Gets the node ID.
|
||||
"""Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
@@ -79,8 +75,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
"""
|
||||
Gets the initiator.
|
||||
"""Gets the initiator.
|
||||
|
||||
Returns:
|
||||
The initiator
|
||||
@@ -88,8 +83,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.initiator
|
||||
|
||||
def get_pol_type(self) -> NodePOLType:
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
"""Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
@@ -97,8 +91,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""
|
||||
Gets the service name.
|
||||
"""Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
@@ -106,8 +99,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
"""Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
@@ -115,8 +107,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
"""Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node id
|
||||
@@ -124,8 +115,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
"""Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service
|
||||
@@ -133,8 +123,7 @@ class NodeStateInstructionRed(object):
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
"""Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service state
|
||||
|
||||
@@ -17,8 +17,7 @@ class PassiveNode(Node):
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
:param node_id: The node id.
|
||||
:param name: The name of the node.
|
||||
@@ -32,8 +31,7 @@ class PassiveNode(Node):
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
"""
|
||||
Gets the node IP address as an empty string.
|
||||
"""Gets the node IP address as an empty string.
|
||||
|
||||
No concept of IP address for passive nodes for now.
|
||||
|
||||
|
||||
@@ -26,8 +26,7 @@ class ServiceNode(ActiveNode):
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
@@ -53,16 +52,14 @@ class ServiceNode(ActiveNode):
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service):
|
||||
"""
|
||||
Adds a service to the node.
|
||||
"""Adds a service to the node.
|
||||
|
||||
:param service: The service to add
|
||||
"""
|
||||
self.services[service.name] = service
|
||||
|
||||
def has_service(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is on a node.
|
||||
"""Indicates whether a service is on a node.
|
||||
|
||||
:param protocol_name: The service (protocol)e.
|
||||
:return: True if service (protocol) is on the node, otherwise False.
|
||||
@@ -73,12 +70,10 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
|
||||
def service_running(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node.
|
||||
"""Indicates whether a service is in a running state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in a running state on the
|
||||
node, otherwise False.
|
||||
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
@@ -89,12 +84,10 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
|
||||
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
"""Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in an overwhelmed state on the
|
||||
node, otherwise False.
|
||||
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
@@ -105,8 +98,7 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
"""Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
@@ -134,8 +126,7 @@ class ServiceNode(ActiveNode):
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
"""Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
Done if the software_state is not "compromised".
|
||||
|
||||
@@ -161,8 +152,7 @@ class ServiceNode(ActiveNode):
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name):
|
||||
"""
|
||||
Gets the state of a service.
|
||||
"""Gets the state of a service.
|
||||
|
||||
:return: The software_state of the service.
|
||||
"""
|
||||
|
||||
@@ -10,8 +10,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session():
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
"""Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
Currently only works on Windows OS.
|
||||
|
||||
|
||||
@@ -25,8 +25,7 @@ def apply_iers(
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
"""Applies IERs to the links (link pattern of life).
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -218,8 +217,7 @@ def apply_node_pol(
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
"""Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
"""Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
@@ -22,8 +21,7 @@ class IER(object):
|
||||
_mission_criticality,
|
||||
_running=False,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
_id: The IER id
|
||||
@@ -49,8 +47,7 @@ class IER(object):
|
||||
self.running = _running
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets IER ID.
|
||||
"""Gets IER ID.
|
||||
|
||||
Returns:
|
||||
IER ID
|
||||
@@ -58,8 +55,7 @@ class IER(object):
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets IER start step.
|
||||
"""Gets IER start step.
|
||||
|
||||
Returns:
|
||||
IER start step
|
||||
@@ -67,8 +63,7 @@ class IER(object):
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets IER end step.
|
||||
"""Gets IER end step.
|
||||
|
||||
Returns:
|
||||
IER end step
|
||||
@@ -76,8 +71,7 @@ class IER(object):
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
"""
|
||||
Gets IER load.
|
||||
"""Gets IER load.
|
||||
|
||||
Returns:
|
||||
IER load
|
||||
@@ -85,8 +79,7 @@ class IER(object):
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
"""
|
||||
Gets IER protocol.
|
||||
"""Gets IER protocol.
|
||||
|
||||
Returns:
|
||||
IER protocol
|
||||
@@ -94,8 +87,7 @@ class IER(object):
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets IER port.
|
||||
"""Gets IER port.
|
||||
|
||||
Returns:
|
||||
IER port
|
||||
@@ -103,8 +95,7 @@ class IER(object):
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
"""Gets IER source node ID.
|
||||
|
||||
Returns:
|
||||
IER source node ID
|
||||
@@ -112,8 +103,7 @@ class IER(object):
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
"""Gets IER destination node ID.
|
||||
|
||||
Returns:
|
||||
IER destination node ID
|
||||
@@ -121,8 +111,7 @@ class IER(object):
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
"""Informs whether the IER is currently running.
|
||||
|
||||
Returns:
|
||||
True if running
|
||||
@@ -130,8 +119,7 @@ class IER(object):
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
"""Sets the running state of the IER.
|
||||
|
||||
Args:
|
||||
_value: running status
|
||||
@@ -139,8 +127,7 @@ class IER(object):
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self):
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
"""Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
Returns:
|
||||
Mission criticality value (0 lowest to 5 highest)
|
||||
|
||||
@@ -24,8 +24,7 @@ def apply_red_agent_iers(
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
"""Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -214,8 +213,7 @@ def apply_red_agent_node_pol(
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
"""Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
@@ -297,8 +295,7 @@ def apply_red_agent_node_pol(
|
||||
|
||||
|
||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
"""
|
||||
Checks if the RED IER is incoming.
|
||||
"""Checks if the RED IER is incoming.
|
||||
|
||||
TODO: Write more descriptive docstring with params and returns.
|
||||
"""
|
||||
|
||||
@@ -18,11 +18,9 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""
|
||||
The PrimaiteSession class.
|
||||
"""The PrimaiteSession class.
|
||||
|
||||
Provides a single learning and evaluation entry point for all training
|
||||
and lay down configurations.
|
||||
Provides a single learning and evaluation entry point for all training and lay down configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -30,8 +28,7 @@ class PrimaiteSession:
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
"""The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: The training config path.
|
||||
:param lay_down_config_path: The lay down config path.
|
||||
@@ -125,8 +122,7 @@ class PrimaiteSession:
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
"""Train the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
@@ -137,8 +133,7 @@ class PrimaiteSession:
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
"""Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
|
||||
@@ -12,11 +12,9 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True):
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
"""Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited
|
||||
notebooks on or off.
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
|
||||
"""
|
||||
notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
|
||||
for subdir, dirs, files in os.walk(notebooks_package_data_root):
|
||||
|
||||
@@ -11,11 +11,9 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing=True):
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
"""Resets the example config files in the users app config directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited
|
||||
config on or off.
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited config on or off.
|
||||
"""
|
||||
configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
"""
|
||||
Handles creation of application directories and user directories.
|
||||
"""Handles creation of application directories and user directories.
|
||||
|
||||
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
|
||||
app directories in the correct locations based on the users OS.
|
||||
|
||||
@@ -10,8 +10,7 @@ class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
|
||||
"""
|
||||
Transaction constructor.
|
||||
"""Transaction constructor.
|
||||
|
||||
:param agent_identifier: An identifier for the agent in use
|
||||
:param episode_number: The episode number
|
||||
@@ -39,8 +38,7 @@ class Transaction(object):
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""
|
||||
Converts the Transaction to a csv data row and provides a header.
|
||||
"""Converts the Transaction to a csv data row and provides a header.
|
||||
|
||||
:return: A tuple consisting of (header, data).
|
||||
"""
|
||||
@@ -69,8 +67,7 @@ class Transaction(object):
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
"""Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
:param action_space: The action space
|
||||
:return: The action space as an array of strings
|
||||
@@ -82,12 +79,10 @@ def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
"""Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
:param obs_space: The observation space
|
||||
:param obs_assets: The number of assets (i.e. nodes or links) in the
|
||||
observation space
|
||||
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
:param obs_features: The number of features associated with the asset
|
||||
:return: The observation space as an array of strings
|
||||
"""
|
||||
|
||||
@@ -10,8 +10,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
"""
|
||||
Get PrimAITE package data.
|
||||
"""Get PrimAITE package data.
|
||||
|
||||
:Example:
|
||||
|
||||
|
||||
@@ -7,11 +7,9 @@ import polars as pl
|
||||
|
||||
|
||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
"""Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean
|
||||
reward that episode.
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode cdv as a dict.
|
||||
|
||||
@@ -12,8 +12,7 @@ _LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SessionOutputWriter:
|
||||
"""
|
||||
A session output writer class.
|
||||
"""A session output writer class.
|
||||
|
||||
Is used to write session outputs to csv file.
|
||||
"""
|
||||
@@ -65,11 +64,9 @@ class SessionOutputWriter:
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
"""
|
||||
Write a row of session data.
|
||||
"""Write a row of session data.
|
||||
|
||||
:param data: The row of data to write. Can be a Tuple or an instance
|
||||
of Transaction.
|
||||
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
|
||||
"""
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
|
||||
Reference in New Issue
Block a user