From 86725064ec9e085f3e94b6ad0f98161417da16f7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 6 Jul 2023 16:08:51 +0100 Subject: [PATCH] Added docstrings to class intialisers --- src/primaite/acl/access_control_list.py | 4 +- src/primaite/acl/acl_rule.py | 18 ++++---- src/primaite/agents/agent.py | 24 ++++++++++- src/primaite/agents/hardcoded_acl.py | 3 +- src/primaite/agents/rllib.py | 12 ++++++ src/primaite/agents/sb3.py | 12 ++++++ src/primaite/agents/simple.py | 12 ++++-- src/primaite/common/protocol.py | 20 +++++---- src/primaite/common/service.py | 14 ++++--- src/primaite/config/training_config.py | 12 ++++-- src/primaite/environment/observations.py | 41 ++++++++++++++----- src/primaite/environment/primaite_env.py | 3 +- src/primaite/links/link.py | 18 ++++---- src/primaite/nodes/active_node.py | 27 ++++++------ src/primaite/nodes/node.py | 20 +++++---- .../nodes/node_state_instruction_green.py | 22 +++++----- .../nodes/node_state_instruction_red.py | 30 +++++++------- src/primaite/nodes/passive_node.py | 20 +++++---- src/primaite/nodes/service_node.py | 26 ++++++------ src/primaite/pol/ier.py | 28 +++++++------ src/primaite/primaite_session.py | 6 ++- src/primaite/transactions/transaction.py | 3 +- src/primaite/utils/session_output_writer.py | 15 ++++++- 23 files changed, 253 insertions(+), 137 deletions(-) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index f6ae3fad..e1d6aa74 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -9,10 +9,12 @@ class AccessControlList: """Access Control List class.""" def __init__(self): + """Initialise an empty AccessControlList.""" self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): - """Checks for IP address matches. + """ + Checks for IP address matches. Args: _rule: The rule being checked diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index 49aebc1b..117c9457 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -3,16 +3,18 @@ class ACLRule: - """Access Control List Rule class. - - :param _permission: The permission (ALLOW or DENY) - :param _source_ip: The source IP address - :param _dest_ip: The destination IP address - :param _protocol: The rule protocol - :param _port: The rule port - """ + """Access Control List Rule class.""" def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port): + """ + Initialise an ACL Rule. + + :param _permission: The permission (ALLOW or DENY) + :param _source_ip: The source IP address + :param _dest_ip: The destination IP address + :param _protocol: The rule protocol + :param _port: The rule port + """ self.permission = _permission self.source_ip = _source_ip self.dest_ip = _dest_ip diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index c68b6df0..7073d795 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -38,7 +38,8 @@ def get_session_path(session_timestamp: datetime) -> Path: class AgentSessionABC(ABC): - """An ABC that manages training and/or evaluation of agents in PrimAITE. + """ + An ABC that manages training and/or evaluation of agents in PrimAITE. This class cannot be directly instantiated and must be inherited from with all implemented abstract methods implemented. @@ -46,6 +47,15 @@ class AgentSessionABC(ABC): @abstractmethod def __init__(self, training_config_path, lay_down_config_path): + """ + Initialise an agent session from config files. + + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + """ if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path, str]] = training_config_path @@ -289,13 +299,23 @@ class AgentSessionABC(ABC): class HardCodedAgentSessionABC(AgentSessionABC): - """An Agent Session ABC for evaluation deterministic agents. + """ + An Agent Session ABC for evaluation deterministic agents. This class cannot be directly instantiated and must be inherited from with all implemented abstract methods implemented. """ def __init__(self, training_config_path, lay_down_config_path): + """ + Initialise a hardcoded agent session. + + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + """ super().__init__(training_config_path, lay_down_config_path) self._setup() diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 9ed9fd28..5cc06bdc 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -23,7 +23,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return self._calculate_action_full_view(obs) def get_blocked_green_iers(self, green_iers, acl, nodes): - """Get blocked green IERs. + """ + Get blocked green IERs. TODO: Add params and return in docstring. TODO: Typehint params and return. diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 20503459..044b760f 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -42,6 +42,18 @@ class RLlibAgent(AgentSessionABC): """An AgentSession class that implements a Ray RLlib agent.""" def __init__(self, training_config_path, lay_down_config_path): + """ + Initialise the RLLib Agent training session. + + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + :raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB") + :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` + or `A2C`) + """ super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 58148d1f..b81a0a18 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -19,6 +19,18 @@ class SB3Agent(AgentSessionABC): """An AgentSession class that implements a Stable Baselines3 agent.""" def __init__(self, training_config_path, lay_down_config_path): + """ + Initialise the SB3 Agent training session. + + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + :raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3") + :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` + or `A2C`) + """ super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.SB3: msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index df93e56d..b429a2f5 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -3,7 +3,8 @@ from primaite.agents.utils import get_new_action, transform_action_acl_enum, tra class RandomAgent(HardCodedAgentSessionABC): - """A Random Agent. + """ + A Random Agent. Get a completely random action from the action space. """ @@ -13,7 +14,8 @@ class RandomAgent(HardCodedAgentSessionABC): class DummyAgent(HardCodedAgentSessionABC): - """A Dummy Agent. + """ + A Dummy Agent. All action spaces setup so dummy action is always 0 regardless of action type used. """ @@ -23,7 +25,8 @@ class DummyAgent(HardCodedAgentSessionABC): class DoNothingACLAgent(HardCodedAgentSessionABC): - """A do nothing ACL agent. + """ + A do nothing ACL agent. A valid ACL action that has no effect; does nothing. """ @@ -37,7 +40,8 @@ class DoNothingACLAgent(HardCodedAgentSessionABC): class DoNothingNodeAgent(HardCodedAgentSessionABC): - """A do nothing Node agent. + """ + A do nothing Node agent. A valid Node action that has no effect; does nothing. """ diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py index ec67caa3..ad6a1d83 100644 --- a/src/primaite/common/protocol.py +++ b/src/primaite/common/protocol.py @@ -3,17 +3,21 @@ class Protocol(object): - """Protocol class. - - :param _name: The protocol name - """ + """Protocol class.""" def __init__(self, _name): + """ + Initialise a protocol. + + :param _name: The name of the protocol + :type _name: str + """ self.name = _name self.load = 0 # bps def get_name(self): - """Gets the protocol name. + """ + Gets the protocol name. Returns: The protocol name @@ -21,7 +25,8 @@ class Protocol(object): return self.name def get_load(self): - """Gets the protocol load. + """ + Gets the protocol load. Returns: The protocol load (bps) @@ -29,7 +34,8 @@ class Protocol(object): return self.load def add_load(self, _load): - """Adds load to the protocol. + """ + Adds load to the protocol. Args: _load: The load to add diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py index 2d08a3c5..258ac8f9 100644 --- a/src/primaite/common/service.py +++ b/src/primaite/common/service.py @@ -5,14 +5,16 @@ from primaite.common.enums import SoftwareState class Service(object): - """Service class. - - :param name: The service name. - :param port: The service port. - :param software_state: The service SoftwareState. - """ + """Service class.""" def __init__(self, name: str, port: str, software_state: SoftwareState): + """ + Initialise a service. + + :param name: The service name. + :param port: The service port. + :param software_state: The service SoftwareState. + """ self.name = name self.port = port self.software_state = software_state diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 040ef6fa..7bdf7995 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -24,7 +24,8 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training def main_training_config_path() -> Path: - """The path to the example training_config_main.yaml file. + """ + The path to the example training_config_main.yaml file. :return: The file path. """ @@ -234,7 +235,8 @@ class TrainingConfig: def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: - """Read in a training config yaml file. + """ + Read in a training config yaml file. :param file_path: The config file path. :param legacy_file: True if the config file is legacy format, otherwise @@ -278,7 +280,8 @@ def convert_legacy_training_config_dict( action_type: ActionType = ActionType.ANY, num_steps: int = 256, ) -> Dict[str, Any]: - """Convert a legacy training config dict to the new format. + """ + Convert a legacy training config dict to the new format. :param legacy_config_dict: A legacy training config dict. :param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values. @@ -305,7 +308,8 @@ def convert_legacy_training_config_dict( def _get_new_key_from_legacy(legacy_key: str) -> str: - """Maps legacy training config keys to the new format keys. + """ + Maps legacy training config keys to the new format keys. :param legacy_key: A legacy training config key. :return: The mapped key. diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 4d027326..28e85b7f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -25,6 +25,12 @@ class AbstractObservationComponent(ABC): @abstractmethod def __init__(self, env: "Primaite"): + """ + Initialise observation component. + + :param env: Primaite training environment. + :type env: Primaite + """ _LOGGER.info(f"Initialising {self} observation component") self.env: "Primaite" = env self.space: spaces.Space @@ -68,6 +74,11 @@ class NodeLinkTable(AbstractObservationComponent): _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): + """Initialise a NodeLinkTable observation space component. + + :param env: Training environment. + :type env: Primaite + """ super().__init__(env) # 1. Define the shape of your observation space component @@ -192,14 +203,17 @@ class NodeStatuses(AbstractObservationComponent): node2 serviceN state (one for each service), ... ] - - :param env: The environment that forms the basis of the observations - :type env: Primaite """ _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): + """ + Initialise a NodeStatuses observation component. + + :param env: Training environment. + :type env: Primaite + """ super().__init__(env) # 1. Define the shape of your observation space component @@ -288,14 +302,6 @@ class LinkTrafficLevels(AbstractObservationComponent): The lowest category always corresponds to no traffic and the highest category to the link being at max capacity. Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories. - :param env: The environment that forms the basis of the observations - :type env: Primaite - :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, - defaults to False - :type combine_service_traffic: bool, optional - :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value , - defaults to 5 - :type quantisation_levels: int, optional """ _DATA_TYPE: type = np.int64 @@ -306,6 +312,18 @@ class LinkTrafficLevels(AbstractObservationComponent): combine_service_traffic: bool = False, quantisation_levels: int = 5, ): + """ + Initialise a LinkTrafficLevels observation component. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, + defaults to False + :type combine_service_traffic: bool, optional + :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical + value, defaults to 5 + :type quantisation_levels: int, optional + """ if quantisation_levels < 3: _msg = ( f"quantisation_levels must be 3 or more because the lowest and highest levels are " @@ -390,6 +408,7 @@ class ObservationsHandler: } def __init__(self): + """Initialise the observation handler.""" self.registered_obs_components: List[AbstractObservationComponent] = [] # internal the observation space (unflattened version of space if flatten=True) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 29662988..825818fd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -67,7 +67,8 @@ class Primaite(Env): session_path: Path, timestamp_str: str, ): - """The Primaite constructor. + """ + The Primaite constructor. :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index ff73ccc8..5892b8e2 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -6,16 +6,18 @@ from primaite.common.protocol import Protocol class Link(object): - """Link class. - - :param _id: The IER id - :param _bandwidth: The bandwidth of the link (bps) - :param _source_node_name: The name of the source node - :param _dest_node_name: The name of the destination node - :param _protocols: The protocols to add to the link - """ + """Link class.""" def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): + """ + Initialise a Link within the simulated network. + + :param _id: The IER id + :param _bandwidth: The bandwidth of the link (bps) + :param _source_node_name: The name of the source node + :param _dest_node_name: The name of the destination node + :param _protocols: The protocols to add to the link + """ self.id = _id self.bandwidth = _bandwidth self.source_node_name = _source_node_name diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index e20ce0e0..3789b7a4 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -11,19 +11,7 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__) class ActiveNode(Node): - """Active Node class. - - :param node_id: The node ID - :param name: The node name - :param node_type: The node type (enum) - :param priority: The node priority (enum) - :param hardware_state: The node Hardware State - :param ip_address: The node IP address - :param software_state: The node Software State - :param file_system_state: The node file system state - :param config_values: The config values - - """ + """Active Node class.""" def __init__( self, @@ -37,6 +25,19 @@ class ActiveNode(Node): file_system_state: FileSystemState, config_values: TrainingConfig, ): + """ + Initialise an active node. + + :param node_id: The node ID + :param name: The node name + :param node_type: The node type (enum) + :param priority: The node priority (enum) + :param hardware_state: The node Hardware State + :param ip_address: The node IP address + :param software_state: The node Software State + :param file_system_state: The node file system state + :param config_values: The config values + """ super().__init__(node_id, name, node_type, priority, hardware_state, config_values) self.ip_address: str = ip_address # Related to Software diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index b54989bf..9fd5b719 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -7,15 +7,7 @@ from primaite.config.training_config import TrainingConfig class Node: - """Node class. - - :param node_id: The node id. - :param name: The name of the node. - :param node_type: The type of the node. - :param priority: The priority of the node. - :param hardware_state: The state of the node. - :param config_values: Config values. - """ + """Node class.""" def __init__( self, @@ -26,6 +18,16 @@ class Node: hardware_state: HardwareState, config_values: TrainingConfig, ): + """ + Initialise a node. + + :param node_id: The node id. + :param name: The name of the node. + :param node_type: The type of the node. + :param priority: The priority of the node. + :param hardware_state: The state of the node. + :param config_values: Config values. + """ self.node_id: Final[str] = node_id self.name: Final[str] = name self.node_type: Final[NodeType] = node_type diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 0faef627..da4be35e 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -3,16 +3,7 @@ class NodeStateInstructionGreen(object): - """The Node State Instruction class. - - :param _id: The node state instruction id - :param _start_step: The start step of the instruction - :param _end_step: The end step of the instruction - :param _node_id: The id of the associated node - :param _node_pol_type: The pattern of life type - :param _service_name: The service name - :param _state: The state (node or service) - """ + """The Node State Instruction class.""" def __init__( self, @@ -24,6 +15,17 @@ class NodeStateInstructionGreen(object): _service_name, _state, ): + """ + Initialise the Node State Instruction. + + :param _id: The node state instruction id + :param _start_step: The start step of the instruction + :param _end_step: The end step of the instruction + :param _node_id: The id of the associated node + :param _node_pol_type: The pattern of life type + :param _service_name: The service name + :param _state: The state (node or service) + """ self.id = _id self.start_step = _start_step self.end_step = _end_step diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 8308a1c0..f8ce4e74 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -7,20 +7,7 @@ from primaite.common.enums import NodePOLType @dataclass() class NodeStateInstructionRed(object): - """The Node State Instruction class. - - :param _id: The node state instruction id - :param _start_step: The start step of the instruction - :param _end_step: The end step of the instruction - :param _target_node_id: The id of the associated node - :param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE) - :param _pol_type: The pattern of life type - :param pol_protocol: The pattern of life protocol/service affected - :param _pol_state: The state (node or service) - :param _pol_source_node_id: The source node Id (used for initiator type SERVICE) - :param _pol_source_node_service: The source node service (used for initiator type SERVICE) - :param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE) - """ + """The Node State Instruction class.""" def __init__( self, @@ -36,6 +23,21 @@ class NodeStateInstructionRed(object): _pol_source_node_service, _pol_source_node_service_state, ): + """ + Initialise the Node State Instruction for the red agent. + + :param _id: The node state instruction id + :param _start_step: The start step of the instruction + :param _end_step: The end step of the instruction + :param _target_node_id: The id of the associated node + :param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE) + :param _pol_type: The pattern of life type + :param pol_protocol: The pattern of life protocol/service affected + :param _pol_state: The state (node or service) + :param _pol_source_node_id: The source node Id (used for initiator type SERVICE) + :param _pol_source_node_service: The source node service (used for initiator type SERVICE) + :param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE) + """ self.id = _id self.start_step = _start_step self.end_step = _end_step diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index fa289593..13b2d6ad 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -6,15 +6,7 @@ from primaite.nodes.node import Node class PassiveNode(Node): - """The Passive Node class. - - :param node_id: The node id. - :param name: The name of the node. - :param node_type: The type of the node. - :param priority: The priority of the node. - :param hardware_state: The state of the node. - :param config_values: Config values. - """ + """The Passive Node class.""" def __init__( self, @@ -25,6 +17,16 @@ class PassiveNode(Node): hardware_state: HardwareState, config_values: TrainingConfig, ): + """ + Initialise a passive node. + + :param node_id: The node id. + :param name: The name of the node. + :param node_type: The type of the node. + :param priority: The priority of the node. + :param hardware_state: The state of the node. + :param config_values: Config values. + """ # Pass through to Super for now super().__init__(node_id, name, node_type, priority, hardware_state, config_values) diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index db435c7d..7632e944 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -12,18 +12,7 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__) class ServiceNode(ActiveNode): - """ServiceNode class. - - :param node_id: The node ID - :param name: The node name - :param node_type: The node type (enum) - :param priority: The node priority (enum) - :param hardware_state: The node Hardware State - :param ip_address: The node IP address - :param software_state: The node Software State - :param file_system_state: The node file system state - :param config_values: The config values - """ + """ServiceNode class.""" def __init__( self, @@ -37,6 +26,19 @@ class ServiceNode(ActiveNode): file_system_state: FileSystemState, config_values: TrainingConfig, ): + """ + Initialise a Service Node. + + :param node_id: The node ID + :param name: The node name + :param node_type: The node type (enum) + :param priority: The node priority (enum) + :param hardware_state: The node Hardware State + :param ip_address: The node IP address + :param software_state: The node Software State + :param file_system_state: The node file system state + :param config_values: The config values + """ super().__init__( node_id, name, diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py index bfbc9a31..913a06da 100644 --- a/src/primaite/pol/ier.py +++ b/src/primaite/pol/ier.py @@ -6,19 +6,7 @@ Used to represent an information flow from source to destination. class IER(object): - """Information Exchange Requirement class. - - :param _id: The IER id - :param _start_step: The step when this IER should start - :param _end_step: The step when this IER should end - :param _load: The load this IER should put on a link (bps) - :param _protocol: The protocol of this IER - :param _port: The port this IER runs on - :param _source_node_id: The source node ID - :param _dest_node_id: The destination node ID - :param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical) - :param _running: Indicates whether the IER is currently running - """ + """Information Exchange Requirement class.""" def __init__( self, @@ -33,6 +21,20 @@ class IER(object): _mission_criticality, _running=False, ): + """ + Initialise an Information Exchange Request. + + :param _id: The IER id + :param _start_step: The step when this IER should start + :param _end_step: The step when this IER should end + :param _load: The load this IER should put on a link (bps) + :param _protocol: The protocol of this IER + :param _port: The port this IER runs on + :param _source_node_id: The source node ID + :param _dest_node_id: The destination node ID + :param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical) + :param _running: Indicates whether the IER is currently running + """ self.id = _id self.start_step = _start_step self.end_step = _end_step diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 4aa8476a..1bfb7403 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -19,7 +19,8 @@ _LOGGER = getLogger(__name__) class PrimaiteSession: - """The PrimaiteSession class. + """ + The PrimaiteSession class. Provides a single learning and evaluation entry point for all training and lay down configurations. """ @@ -29,7 +30,8 @@ class PrimaiteSession: training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], ): - """The PrimaiteSession constructor. + """ + The PrimaiteSession constructor. :param training_config_path: The training config path. :param lay_down_config_path: The lay down config path. diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 21d4ee05..a74ef4f9 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -10,7 +10,8 @@ class Transaction(object): """Transaction class.""" def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int): - """Transaction constructor. + """ + Transaction constructor. :param agent_identifier: An identifier for the agent in use :param episode_number: The episode number diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 939ebdb5..5852a84d 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -12,7 +12,8 @@ _LOGGER: Logger = getLogger(__name__) class SessionOutputWriter: - """A session output writer class. + """ + A session output writer class. Is used to write session outputs to csv file. """ @@ -28,6 +29,18 @@ class SessionOutputWriter: transaction_writer: bool = False, learning_session: bool = True, ): + """ + Initialise the Session Output Writer. + + :param env: PrimAITE gym environment. + :type env: Primaite + :param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent. + If `false` it will output the average reward per episode, defaults to False + :type transaction_writer: bool, optional + :param learning_session: Set to `true` to indicate that the current session is a training session. This + determines the name of the folder which contains the final output csv. Defaults to True + :type learning_session: bool, optional + """ self._env = env self.transaction_writer = transaction_writer self.learning_session = learning_session