#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -3,13 +3,7 @@
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
@@ -44,9 +38,7 @@ class ActiveNode(Node):
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(
|
||||
node_id, name, node_type, priority, hardware_state, config_values
|
||||
)
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
self.ip_address: str = ip_address
|
||||
# Related to Software
|
||||
self._software_state: SoftwareState = software_state
|
||||
@@ -87,9 +79,7 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(
|
||||
self, software_state: SoftwareState
|
||||
):
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
@@ -100,9 +90,7 @@ class ActiveNode(Node):
|
||||
if self._software_state != SoftwareState.COMPROMISED:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = (
|
||||
self.config_values.os_patching_duration
|
||||
)
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be changed."
|
||||
@@ -129,14 +117,10 @@ class ActiveNode(Node):
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
@@ -149,9 +133,7 @@ class ActiveNode(Node):
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(
|
||||
self, file_system_state: FileSystemState
|
||||
):
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
@@ -168,14 +150,10 @@ class ActiveNode(Node):
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
@@ -191,9 +169,7 @@ class ActiveNode(Node):
|
||||
def start_file_system_scan(self):
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = (
|
||||
self.config_values.file_system_scanning_limit
|
||||
)
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def update_file_system_state(self):
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
@@ -212,10 +188,7 @@ class ActiveNode(Node):
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
|
||||
# Scanning updates
|
||||
if (
|
||||
self.file_system_scanning == True
|
||||
and self.file_system_scanning_count < 0
|
||||
):
|
||||
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
|
||||
self.file_system_state_observed = self.file_system_state_actual
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
@@ -32,9 +32,7 @@ class NodeStateInstructionGreen(object):
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = (
|
||||
_service_name # Not used when not a service instruction
|
||||
)
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
|
||||
@@ -42,9 +42,7 @@ class NodeStateInstructionRed(object):
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name = (
|
||||
pol_protocol # Not used when not a service instruction
|
||||
)
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
|
||||
@@ -28,9 +28,7 @@ class PassiveNode(Node):
|
||||
:param config_values: Config values.
|
||||
"""
|
||||
# Pass through to Super for now
|
||||
super().__init__(
|
||||
node_id, name, node_type, priority, hardware_state, config_values
|
||||
)
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
|
||||
@@ -3,13 +3,7 @@
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -110,9 +104,7 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(
|
||||
self, protocol_name: str, software_state: SoftwareState
|
||||
):
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -130,9 +122,7 @@ class ServiceNode(ActiveNode):
|
||||
) or software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
@@ -143,9 +133,7 @@ class ServiceNode(ActiveNode):
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(
|
||||
self, protocol_name: str, software_state: SoftwareState
|
||||
):
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -161,9 +149,7 @@ class ServiceNode(ActiveNode):
|
||||
if service_value.software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
|
||||
Reference in New Issue
Block a user