Added docstrings to class intialisers

This commit is contained in:
Marek Wolan
2023-07-06 16:08:51 +01:00
parent 2a08d3a2a5
commit 86725064ec
23 changed files with 253 additions and 137 deletions

View File

@@ -9,10 +9,12 @@ class AccessControlList:
"""Access Control List class."""
def __init__(self):
"""Initialise an empty AccessControlList."""
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""Checks for IP address matches.
"""
Checks for IP address matches.
Args:
_rule: The rule being checked

View File

@@ -3,16 +3,18 @@
class ACLRule:
"""Access Control List Rule class.
:param _permission: The permission (ALLOW or DENY)
:param _source_ip: The source IP address
:param _dest_ip: The destination IP address
:param _protocol: The rule protocol
:param _port: The rule port
"""
"""Access Control List Rule class."""
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""
Initialise an ACL Rule.
:param _permission: The permission (ALLOW or DENY)
:param _source_ip: The source IP address
:param _dest_ip: The destination IP address
:param _protocol: The rule protocol
:param _port: The rule port
"""
self.permission = _permission
self.source_ip = _source_ip
self.dest_ip = _dest_ip

View File

@@ -38,7 +38,8 @@ def get_session_path(session_timestamp: datetime) -> Path:
class AgentSessionABC(ABC):
"""An ABC that manages training and/or evaluation of agents in PrimAITE.
"""
An ABC that manages training and/or evaluation of agents in PrimAITE.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
@@ -46,6 +47,15 @@ class AgentSessionABC(ABC):
@abstractmethod
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise an agent session from config files.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
@@ -289,13 +299,23 @@ class AgentSessionABC(ABC):
class HardCodedAgentSessionABC(AgentSessionABC):
"""An Agent Session ABC for evaluation deterministic agents.
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path)
self._setup()

View File

@@ -23,7 +23,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes):
"""Get blocked green IERs.
"""
Get blocked green IERs.
TODO: Add params and return in docstring.
TODO: Typehint params and return.

View File

@@ -42,6 +42,18 @@ class RLlibAgent(AgentSessionABC):
"""An AgentSession class that implements a Ray RLlib agent."""
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise the RLLib Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "RLLIB")
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"

View File

@@ -19,6 +19,18 @@ class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(self, training_config_path, lay_down_config_path):
"""
Initialise the SB3 Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"

View File

@@ -3,7 +3,8 @@ from primaite.agents.utils import get_new_action, transform_action_acl_enum, tra
class RandomAgent(HardCodedAgentSessionABC):
"""A Random Agent.
"""
A Random Agent.
Get a completely random action from the action space.
"""
@@ -13,7 +14,8 @@ class RandomAgent(HardCodedAgentSessionABC):
class DummyAgent(HardCodedAgentSessionABC):
"""A Dummy Agent.
"""
A Dummy Agent.
All action spaces setup so dummy action is always 0 regardless of action type used.
"""
@@ -23,7 +25,8 @@ class DummyAgent(HardCodedAgentSessionABC):
class DoNothingACLAgent(HardCodedAgentSessionABC):
"""A do nothing ACL agent.
"""
A do nothing ACL agent.
A valid ACL action that has no effect; does nothing.
"""
@@ -37,7 +40,8 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
class DoNothingNodeAgent(HardCodedAgentSessionABC):
"""A do nothing Node agent.
"""
A do nothing Node agent.
A valid Node action that has no effect; does nothing.
"""

View File

@@ -3,17 +3,21 @@
class Protocol(object):
"""Protocol class.
:param _name: The protocol name
"""
"""Protocol class."""
def __init__(self, _name):
"""
Initialise a protocol.
:param _name: The name of the protocol
:type _name: str
"""
self.name = _name
self.load = 0 # bps
def get_name(self):
"""Gets the protocol name.
"""
Gets the protocol name.
Returns:
The protocol name
@@ -21,7 +25,8 @@ class Protocol(object):
return self.name
def get_load(self):
"""Gets the protocol load.
"""
Gets the protocol load.
Returns:
The protocol load (bps)
@@ -29,7 +34,8 @@ class Protocol(object):
return self.load
def add_load(self, _load):
"""Adds load to the protocol.
"""
Adds load to the protocol.
Args:
_load: The load to add

View File

@@ -5,14 +5,16 @@ from primaite.common.enums import SoftwareState
class Service(object):
"""Service class.
:param name: The service name.
:param port: The service port.
:param software_state: The service SoftwareState.
"""
"""Service class."""
def __init__(self, name: str, port: str, software_state: SoftwareState):
"""
Initialise a service.
:param name: The service name.
:param port: The service port.
:param software_state: The service SoftwareState.
"""
self.name = name
self.port = port
self.software_state = software_state

View File

@@ -24,7 +24,8 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
def main_training_config_path() -> Path:
"""The path to the example training_config_main.yaml file.
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
@@ -234,7 +235,8 @@ class TrainingConfig:
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""Read in a training config yaml file.
"""
Read in a training config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise
@@ -278,7 +280,8 @@ def convert_legacy_training_config_dict(
action_type: ActionType = ActionType.ANY,
num_steps: int = 256,
) -> Dict[str, Any]:
"""Convert a legacy training config dict to the new format.
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values.
@@ -305,7 +308,8 @@ def convert_legacy_training_config_dict(
def _get_new_key_from_legacy(legacy_key: str) -> str:
"""Maps legacy training config keys to the new format keys.
"""
Maps legacy training config keys to the new format keys.
:param legacy_key: A legacy training config key.
:return: The mapped key.

View File

@@ -25,6 +25,12 @@ class AbstractObservationComponent(ABC):
@abstractmethod
def __init__(self, env: "Primaite"):
"""
Initialise observation component.
:param env: Primaite training environment.
:type env: Primaite
"""
_LOGGER.info(f"Initialising {self} observation component")
self.env: "Primaite" = env
self.space: spaces.Space
@@ -68,6 +74,11 @@ class NodeLinkTable(AbstractObservationComponent):
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
"""Initialise a NodeLinkTable observation space component.
:param env: Training environment.
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
@@ -192,14 +203,17 @@ class NodeStatuses(AbstractObservationComponent):
node2 serviceN state (one for each service),
...
]
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
"""
Initialise a NodeStatuses observation component.
:param env: Training environment.
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
@@ -288,14 +302,6 @@ class LinkTrafficLevels(AbstractObservationComponent):
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
defaults to 5
:type quantisation_levels: int, optional
"""
_DATA_TYPE: type = np.int64
@@ -306,6 +312,18 @@ class LinkTrafficLevels(AbstractObservationComponent):
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
"""
Initialise a LinkTrafficLevels observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical
value, defaults to 5
:type quantisation_levels: int, optional
"""
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
@@ -390,6 +408,7 @@ class ObservationsHandler:
}
def __init__(self):
"""Initialise the observation handler."""
self.registered_obs_components: List[AbstractObservationComponent] = []
# internal the observation space (unflattened version of space if flatten=True)

View File

@@ -67,7 +67,8 @@ class Primaite(Env):
session_path: Path,
timestamp_str: str,
):
"""The Primaite constructor.
"""
The Primaite constructor.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.

View File

@@ -6,16 +6,18 @@ from primaite.common.protocol import Protocol
class Link(object):
"""Link class.
:param _id: The IER id
:param _bandwidth: The bandwidth of the link (bps)
:param _source_node_name: The name of the source node
:param _dest_node_name: The name of the destination node
:param _protocols: The protocols to add to the link
"""
"""Link class."""
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
"""
Initialise a Link within the simulated network.
:param _id: The IER id
:param _bandwidth: The bandwidth of the link (bps)
:param _source_node_name: The name of the source node
:param _dest_node_name: The name of the destination node
:param _protocols: The protocols to add to the link
"""
self.id = _id
self.bandwidth = _bandwidth
self.source_node_name = _source_node_name

View File

@@ -11,19 +11,7 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class ActiveNode(Node):
"""Active Node class.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
"""Active Node class."""
def __init__(
self,
@@ -37,6 +25,19 @@ class ActiveNode(Node):
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
"""
Initialise an active node.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
self.ip_address: str = ip_address
# Related to Software

View File

@@ -7,15 +7,7 @@ from primaite.config.training_config import TrainingConfig
class Node:
"""Node class.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
"""Node class."""
def __init__(
self,
@@ -26,6 +18,16 @@ class Node:
hardware_state: HardwareState,
config_values: TrainingConfig,
):
"""
Initialise a node.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
self.node_id: Final[str] = node_id
self.name: Final[str] = name
self.node_type: Final[NodeType] = node_type

View File

@@ -3,16 +3,7 @@
class NodeStateInstructionGreen(object):
"""The Node State Instruction class.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _node_id: The id of the associated node
:param _node_pol_type: The pattern of life type
:param _service_name: The service name
:param _state: The state (node or service)
"""
"""The Node State Instruction class."""
def __init__(
self,
@@ -24,6 +15,17 @@ class NodeStateInstructionGreen(object):
_service_name,
_state,
):
"""
Initialise the Node State Instruction.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _node_id: The id of the associated node
:param _node_pol_type: The pattern of life type
:param _service_name: The service name
:param _state: The state (node or service)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step

View File

@@ -7,20 +7,7 @@ from primaite.common.enums import NodePOLType
@dataclass()
class NodeStateInstructionRed(object):
"""The Node State Instruction class.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _target_node_id: The id of the associated node
:param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
:param _pol_type: The pattern of life type
:param pol_protocol: The pattern of life protocol/service affected
:param _pol_state: The state (node or service)
:param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
"""The Node State Instruction class."""
def __init__(
self,
@@ -36,6 +23,21 @@ class NodeStateInstructionRed(object):
_pol_source_node_service,
_pol_source_node_service_state,
):
"""
Initialise the Node State Instruction for the red agent.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _target_node_id: The id of the associated node
:param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
:param _pol_type: The pattern of life type
:param pol_protocol: The pattern of life protocol/service affected
:param _pol_state: The state (node or service)
:param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step

View File

@@ -6,15 +6,7 @@ from primaite.nodes.node import Node
class PassiveNode(Node):
"""The Passive Node class.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
"""The Passive Node class."""
def __init__(
self,
@@ -25,6 +17,16 @@ class PassiveNode(Node):
hardware_state: HardwareState,
config_values: TrainingConfig,
):
"""
Initialise a passive node.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
# Pass through to Super for now
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)

View File

@@ -12,18 +12,7 @@ _LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class ServiceNode(ActiveNode):
"""ServiceNode class.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
"""ServiceNode class."""
def __init__(
self,
@@ -37,6 +26,19 @@ class ServiceNode(ActiveNode):
file_system_state: FileSystemState,
config_values: TrainingConfig,
):
"""
Initialise a Service Node.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(
node_id,
name,

View File

@@ -6,19 +6,7 @@ Used to represent an information flow from source to destination.
class IER(object):
"""Information Exchange Requirement class.
:param _id: The IER id
:param _start_step: The step when this IER should start
:param _end_step: The step when this IER should end
:param _load: The load this IER should put on a link (bps)
:param _protocol: The protocol of this IER
:param _port: The port this IER runs on
:param _source_node_id: The source node ID
:param _dest_node_id: The destination node ID
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
"""Information Exchange Requirement class."""
def __init__(
self,
@@ -33,6 +21,20 @@ class IER(object):
_mission_criticality,
_running=False,
):
"""
Initialise an Information Exchange Request.
:param _id: The IER id
:param _start_step: The step when this IER should start
:param _end_step: The step when this IER should end
:param _load: The load this IER should put on a link (bps)
:param _protocol: The protocol of this IER
:param _port: The port this IER runs on
:param _source_node_id: The source node ID
:param _dest_node_id: The destination node ID
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step

View File

@@ -19,7 +19,8 @@ _LOGGER = getLogger(__name__)
class PrimaiteSession:
"""The PrimaiteSession class.
"""
The PrimaiteSession class.
Provides a single learning and evaluation entry point for all training and lay down configurations.
"""
@@ -29,7 +30,8 @@ class PrimaiteSession:
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
"""The PrimaiteSession constructor.
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.

View File

@@ -10,7 +10,8 @@ class Transaction(object):
"""Transaction class."""
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
"""Transaction constructor.
"""
Transaction constructor.
:param agent_identifier: An identifier for the agent in use
:param episode_number: The episode number

View File

@@ -12,7 +12,8 @@ _LOGGER: Logger = getLogger(__name__)
class SessionOutputWriter:
"""A session output writer class.
"""
A session output writer class.
Is used to write session outputs to csv file.
"""
@@ -28,6 +29,18 @@ class SessionOutputWriter:
transaction_writer: bool = False,
learning_session: bool = True,
):
"""
Initialise the Session Output Writer.
:param env: PrimAITE gym environment.
:type env: Primaite
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
If `false` it will output the average reward per episode, defaults to False
:type transaction_writer: bool, optional
:param learning_session: Set to `true` to indicate that the current session is a training session. This
determines the name of the folder which contains the final output csv. Defaults to True
:type learning_session: bool, optional
"""
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session