#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 203cc98494
commit 27ca53878a
43 changed files with 284 additions and 896 deletions

View File

@@ -3,13 +3,7 @@
import logging
from typing import Dict, Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -110,9 +104,7 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -130,9 +122,7 @@ class ServiceNode(ActiveNode):
) or software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "
@@ -143,9 +133,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -161,9 +149,7 @@ class ServiceNode(ActiveNode):
if service_value.software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "