Apply suggestions from code review

This commit is contained in:
Marek Wolan
2023-11-27 13:28:11 +00:00
parent 43fee23600
commit 89cbc08352
5 changed files with 47 additions and 38 deletions

View File

@@ -82,7 +82,7 @@ class NodeServiceAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, service_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -98,7 +98,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
self.verb: str = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
@@ -106,7 +106,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
self.verb: str = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
@@ -114,7 +114,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
self.verb: str = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
@@ -122,7 +122,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
self.verb: str = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
@@ -130,7 +130,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
self.verb: str = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
@@ -138,7 +138,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
self.verb: str = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
@@ -146,7 +146,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
self.verb: str = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
@@ -154,7 +154,7 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
self.verb: str = "enable"
class NodeApplicationAbstractAction(AbstractAction):
@@ -169,7 +169,7 @@ class NodeApplicationAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, application_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -185,7 +185,7 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb = "execute"
self.verb: str = "execute"
class NodeFolderAbstractAction(AbstractAction):
@@ -200,7 +200,7 @@ class NodeFolderAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -254,7 +254,7 @@ class NodeFileAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -271,7 +271,7 @@ class NodeFileScanAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "scan"
self.verb: str = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
@@ -279,7 +279,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "checkhash"
self.verb: str = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
@@ -287,7 +287,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "delete"
self.verb: str = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
@@ -295,7 +295,7 @@ class NodeFileRepairAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "repair"
self.verb: str = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
@@ -303,7 +303,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "restore"
self.verb: str = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction):
@@ -311,7 +311,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "corrupt"
self.verb: str = "corrupt"
class NodeAbstractAction(AbstractAction):
@@ -325,7 +325,7 @@ class NodeAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -338,7 +338,7 @@ class NodeOSScanAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "scan"
self.verb: str = "scan"
class NodeShutdownAction(NodeAbstractAction):
@@ -346,7 +346,7 @@ class NodeShutdownAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "shutdown"
self.verb: str = "shutdown"
class NodeStartupAction(NodeAbstractAction):
@@ -354,7 +354,7 @@ class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "startup"
self.verb: str = "startup"
class NodeResetAction(NodeAbstractAction):
@@ -362,7 +362,7 @@ class NodeResetAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "reset"
self.verb: str = "reset"
class NetworkACLAddRuleAction(AbstractAction):
@@ -520,7 +520,7 @@ class NetworkNICAbstractAction(AbstractAction):
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, nic_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -543,7 +543,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "enable"
self.verb: str = "enable"
class NetworkNICDisableAction(NetworkNICAbstractAction):
@@ -551,7 +551,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
self.verb: str = "disable"
class ActionManager:

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationManager
@@ -23,6 +23,21 @@ class AgentStartSettings(BaseModel):
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AgentStartSettings":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
@@ -180,9 +195,3 @@ class ProxyAgent(AbstractAgent):
The environment is responsible for calling this method when it receives an action from the agent policy.
"""
self.most_recent_action = action
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""
...

View File

@@ -24,7 +24,7 @@ class DataManipulationAttackStage(IntEnum):
"Represents the stage of performing a horizontal port scan on the target."
ATTACKING = 3
"Stage of actively attacking the target."
COMPLETE = 4
SUCCEEDED = 4
"Indicates the attack has been successfully completed."
FAILED = 5
"Signifies that the attack has failed."
@@ -134,7 +134,7 @@ class DataManipulationBot(DatabaseClient):
attack_successful = True
if attack_successful:
self.sys_log.info(f"{self.name}: Data manipulation successful")
self.attack_stage = DataManipulationAttackStage.COMPLETE
self.attack_stage = DataManipulationAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Data manipulation failed")
self.attack_stage = DataManipulationAttackStage.FAILED
@@ -163,7 +163,7 @@ class DataManipulationBot(DatabaseClient):
self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success)
if self.repeat and self.attack_stage in (
DataManipulationAttackStage.COMPLETE,
DataManipulationAttackStage.SUCCEEDED,
DataManipulationAttackStage.FAILED,
):
self.attack_stage = DataManipulationAttackStage.NOT_STARTED

View File

@@ -69,5 +69,5 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage in (DataManipulationAttackStage.COMPLETE, DataManipulationAttackStage.FAILED)
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
assert dm_bot.connected