#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -3,13 +3,7 @@
import logging
from typing import Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
@@ -44,9 +38,7 @@ class ActiveNode(Node):
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(
node_id, name, node_type, priority, hardware_state, config_values
)
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
self.ip_address: str = ip_address
# Related to Software
self._software_state: SoftwareState = software_state
@@ -87,9 +79,7 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(
self, software_state: SoftwareState
):
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
"""
Sets Software State if the node is not compromised.
@@ -100,9 +90,7 @@ class ActiveNode(Node):
if self._software_state != SoftwareState.COMPROMISED:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = (
self.config_values.os_patching_duration
)
self.patching_count = self.config_values.os_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so OS State cannot be changed."
@@ -129,14 +117,10 @@ class ActiveNode(Node):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
@@ -149,9 +133,7 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(
self, file_system_state: FileSystemState
):
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
"""
Sets the file system state (actual and observed) if not in a compromised state.
@@ -168,14 +150,10 @@ class ActiveNode(Node):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
@@ -191,9 +169,7 @@ class ActiveNode(Node):
def start_file_system_scan(self):
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = (
self.config_values.file_system_scanning_limit
)
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self):
"""Updates file system status based on scanning/restore/repair cycle."""
@@ -212,10 +188,7 @@ class ActiveNode(Node):
self.file_system_state_observed = FileSystemState.GOOD
# Scanning updates
if (
self.file_system_scanning == True
and self.file_system_scanning_count < 0
):
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
self.file_system_state_observed = self.file_system_state_actual
self.file_system_scanning = False
self.file_system_scanning_count = 0

View File

@@ -32,9 +32,7 @@ class NodeStateInstructionGreen(object):
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type = _node_pol_type
self.service_name = (
_service_name # Not used when not a service instruction
)
self.service_name = _service_name # Not used when not a service instruction
self.state = _state
def get_start_step(self):

View File

@@ -42,9 +42,7 @@ class NodeStateInstructionRed(object):
self.target_node_id = _target_node_id
self.initiator = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name = (
pol_protocol # Not used when not a service instruction
)
self.service_name = pol_protocol # Not used when not a service instruction
self.state = _pol_state
self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service

View File

@@ -28,9 +28,7 @@ class PassiveNode(Node):
:param config_values: Config values.
"""
# Pass through to Super for now
super().__init__(
node_id, name, node_type, priority, hardware_state, config_values
)
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
@property
def ip_address(self) -> str:

View File

@@ -3,13 +3,7 @@
import logging
from typing import Dict, Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -110,9 +104,7 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -130,9 +122,7 @@ class ServiceNode(ActiveNode):
) or software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "
@@ -143,9 +133,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -161,9 +149,7 @@ class ServiceNode(ActiveNode):
if service_value.software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "