#2912 - eod commit. Addressing test and lint errors for refactored actions

This commit is contained in:
Charlie Crane
2024-10-21 17:51:55 +01:00
parent a5c7565f0e
commit 11357f87ca
13 changed files with 204 additions and 160 deletions

View File

@@ -1,40 +1,31 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.actions.manager import ActionManager
from primaite.game.agent.actions.service import (
NodeServiceDisableAction,
NodeServiceEnableAction,
NodeServiceFixAction,
NodeServicePauseAction,
NodeServiceRestartAction,
NodeServiceResumeAction,
NodeServiceScanAction,
NodeServiceStartAction,
NodeServiceStopAction,
from primaite.game.agent.actions import (
acl,
application,
config,
file,
folder,
host_nic,
manager,
network,
node,
service,
session,
)
from primaite.game.agent.actions.manager import ActionManager
__all__ = (
"NodeServiceDisableAction",
"NodeServiceEnableAction",
"NodeServiceFixAction",
"NodeServicePauseAction",
"NodeServiceRestartAction",
"NodeServiceResumeAction",
"NodeServiceScanAction",
"NodeServiceStartAction",
"NodeServiceStopAction",
"acl",
"application",
"config",
"file",
"folder",
"host_nic",
"manager",
"network",
"node",
"service",
"session",
"ActionManager",
)
# __all__ = (
# "acl",
# "application",
# "config",
# "file",
# "folder",
# "host_nic",
# "manager",
# "network",
# "node",
# "service",
# )

View File

@@ -4,13 +4,28 @@ from typing import Dict, List, Literal
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction
from primaite.game.game import _LOGGER
from primaite.interface.request import RequestFormat
__all__ = ("RouterACLAddRuleAction", "RouterACLRemoveRuleAction", "FirewallACLAddRuleAction", "FirewallACLRemoveRuleAction")
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for RouterACLAddRuleAction."""
target_router: str
position: int
permission: Literal[1, 2]
source_ip_id: int
source_wildcard_id: int
source_port_id: int
dest_ip_id: int
dest_wildcard_id: int
dest_port_id: int
protocol_id: int
class ACLRuleOptions(BaseModel):
"""Validator for ACL_ADD_RULE options."""
@@ -52,73 +67,31 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
return cls.model_fields[info.field_name].default
return v
def __init__(
self,
manager: "ActionManager",
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
"""Init method for RouterACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
:type num_ips: int
:param num_ports: Number of ports in the simulation.
:type num_ports: int
:param num_protocols: Number of protocols in the simulation.
:type num_protocols: int
"""
super().__init__(manager=manager)
num_permissions = 3
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_id": num_ips,
"dest_ip_id": num_ips,
"source_port_id": num_ports,
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
@classmethod
def form_request(
self,
target_router: str,
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
cls,
config: ConfigSchema
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
# Validate incoming data.
parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
target_router=target_router,
position=position,
permission=permission,
source_ip_id=source_ip_id,
source_wildcard_id=source_wildcard_id,
dest_ip_id=dest_ip_id,
dest_wildcard_id=dest_wildcard_id,
source_port_id=source_port_id,
dest_port_id=dest_port_id,
protocol_id=protocol_id,
target_router=config.target_router,
position=config.position,
permission=config.permission,
source_ip_id=config.source_ip_id,
source_wildcard_id=config.source_wildcard_id,
dest_ip_id=config.dest_ip_id,
dest_wildcard_id=config.dest_wildcard_id,
source_port_id=config.source_port_id,
dest_port_id=config.dest_port_id,
protocol_id=config.protocol_id,
)
if parsed_options.permission == 1:
permission_str = "PERMIT"
elif parsed_options.permission == 2:
permission_str = "DENY"
else:
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
# else:
# _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if parsed_options.protocol_id == 1:
protocol = "ALL"
@@ -246,8 +219,8 @@ class FirewallACLAddRuleAction(AbstractAction, identifier="firewall_acl_add_rule
permission_str = "PERMIT"
elif permission == 2:
permission_str = "DENY"
else:
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
# else:
# _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS

View File

@@ -4,8 +4,17 @@ from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeApplicationExecuteAction",
"NodeApplicationScanAction",
"NodeApplicationCloseAction",
"NodeApplicationFixAction",
"NodeApplicationInstallAction",
"NodeApplicationRemoveAction",
)
class NodeApplicationAbstractAction(AbstractAction):
class NodeApplicationAbstractAction(AbstractAction, identifier="node_application_abstract_action"):
"""
Base class for application actions.
@@ -65,7 +74,7 @@ class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_a
verb: str = "fix"
class NodeApplicationInstallAction(NodeApplicationAbstractAction):
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
"""Action which installs an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
@@ -76,12 +85,10 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction):
# TODO: Either changes to application form_request bits, or add that here.
class NodeApplicationRemoveAction(NodeApplicationAbstractAction):
"""Action which removes/uninstalls an application"""
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
"""Action which removes/uninstalls an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationRemoveAction."""
verb: str = "uninstall"
# TODO: Either changes to application form_request bits, or add that here.

View File

@@ -7,31 +7,39 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationIn
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"ConfigureRansomwareScriptAction",
"ConfigureDoSBotAction",
"ConfigureC2BeaconAction",
"NodeSendRemoteCommandAction",
"TerminalC2ServerAction",
"RansomwareLaunchC2ServerAction",
"ExfiltrationC2ServerAction",
)
class ConfigureRansomwareScriptAction(AbstractAction):
class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_ransomware"):
"""Action which sets config parameters for a ransomware script on a node."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this option."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for ConfigureRansomwareScriptAction."""
model_config = ConfigDict(extra="forbid")
server_ip_address: Optional[str] = None
server_password: Optional[str] = None
payload: Optional[str] = None
node_name: str
server_ip_address: Optional[str]
server_password: Optional[str]
payload: Optional[str]
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
if config.node_name is None:
return ["do_nothing"]
ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", node_name, "application", "RansomwareScript", "configure", config]
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config]
class ConfigureDoSBotAction(AbstractAction):
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
"""Action which sets config parameters for a DoS bot on a node."""
class _Opts(BaseModel):
@@ -58,7 +66,7 @@ class ConfigureDoSBotAction(AbstractAction):
return ["network", "node", node_name, "application", "DoSBot", "configure", config]
class ConfigureC2BeaconAction(AbstractAction):
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2"):
"""Action which configures a C2 Beacon based on the parameters given."""
class ConfigSchema(AbstractAction.ConfigSchema):
@@ -109,28 +117,32 @@ class ConfigureC2BeaconAction(AbstractAction):
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config.__dict__]
class NodeSendRemoteCommandAction(AbstractAction):
class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):
"""Action which sends a terminal command to a remote node via SSH."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NodeSendRemoteCommandAction."""
def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat:
node_name: str
remote_ip: str
command: RequestFormat
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
config.node_name,
"service",
"Terminal",
"send_remote_command",
remote_ip,
{"command": command},
config.remote_ip,
{"command": config.command},
]
class TerminalC2ServerAction(AbstractAction):
class TerminalC2ServerAction(AbstractAction, identifier="terminal_c2_server"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
class _Opts(BaseModel):
@@ -161,11 +173,12 @@ class TerminalC2ServerAction(AbstractAction):
return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model]
class RansomwareLaunchC2ServerAction(AbstractAction):
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="ransomware_launch"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
class ConfigSchema(AbstractAction):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for RansomwareLaunchC2ServerAction."""
node_name: str
@classmethod
@@ -176,7 +189,8 @@ class RansomwareLaunchC2ServerAction(AbstractAction):
# This action currently doesn't require any further configuration options.
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
class ExfiltrationC2ServerAction(AbstractAction):
class ExfiltrationC2ServerAction(AbstractAction, identifier="exfiltration_c2_server"):
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
class _Opts(BaseModel):

View File

@@ -4,8 +4,17 @@ from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeFileCreateAction",
"NodeFileScanAction",
"NodeFileDeleteAction",
"NodeFileRestoreAction",
"NodeFileCorruptAction",
"NodeFileAccessAction",
)
class NodeFileAbstractAction(AbstractAction):
class NodeFileAbstractAction(AbstractAction, identifier="node_file_abstract_action"):
"""Abstract base class for file actions.
Any action which applies to a file and uses node_name, folder_name, and file_name as its

View File

@@ -1,16 +1,23 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, Dict
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeFolderScanAction",
"NodeFolderCheckhashAction",
"NodeFolderRepairAction",
"NodeFolderRestoreAction",
"NodeFolderCreateAction",
)
class NodeFolderAbstractAction(AbstractAction):
class NodeFolderAbstractAction(AbstractAction, identifier="node_folder_abstract"):
"""
Base class for folder actions.
Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from
this base class.
"""

View File

@@ -2,8 +2,10 @@
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("HostNICEnableAction", "HostNICDisableAction")
class HostNICAbstractAction(AbstractAction):
class HostNICAbstractAction(AbstractAction, identifier="host_nic_abstract"):
"""
Abstract base class for NIC actions.

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""yaml example
"""yaml example.
agents:
- name: agent_1
@@ -20,7 +20,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict
from primaite.game.game import PrimaiteGame
# from primaite.game.game import PrimaiteGame # TODO: Breaks things
from primaite.interface.request import RequestFormat
# TODO: Make sure that actions are backwards compatible where the old YAML format is used.
@@ -37,6 +37,8 @@ class AbstractAction(BaseModel):
# CAOS actions to requests for simulator. Similar to the network node adder, that class also doesn't need to be
# instantiated.)
class ConfigSchema(BaseModel, ABC): # TODO: not sure if this better named something like `Options`
"""Base configuration schema for Actions."""
model_config = ConfigDict(extra="forbid")
type: str
@@ -54,8 +56,12 @@ class AbstractAction(BaseModel):
return []
class DoNothingAction(AbstractAction):
class DoNothingAction(AbstractAction, identifier="do_nothing"):
"""Do Nothing Action."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for DoNothingAction."""
type: Literal["do_nothing"] = "do_nothing"
def form_request(self, options: ConfigSchema) -> RequestFormat:

View File

@@ -1,18 +1,19 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import ClassVar, Dict, Optional
from pydantic import BaseModel, ConfigDict
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
class NetworkPortAbstractAction(AbstractAction):
"""Base class for Network port actions"""
class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"):
"""Base class for Network port actions."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NetworkPort actions."""
target_nodename: str
port_id: str
@@ -21,7 +22,7 @@ class NetworkPortAbstractAction(AbstractAction):
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.target_nodename is None or config.port_id is None:
if config.target_nodename is None or config.port_id is None:
return ["do_nothing"]
return ["network", "node", config.target_nodename, "network_interface", config.port_id, cls.verb]
@@ -34,11 +35,11 @@ class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_por
verb: str = "enable"
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
"""Action which disables are port on a router or a firewall."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortDisableAction"""
"""Configuration schema for NetworkPortDisableAction."""
verb: str = "disable"

View File

@@ -1,12 +1,13 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, Dict
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("NodeOSScanAction", "NodeShutdownAction", "NodeStartupAction", "NodeResetAction")
class NodeAbstractAction(AbstractAction):
class NodeAbstractAction(AbstractAction, identifier="node_abstract"):
"""
Abstract base class for node actions.

View File

@@ -4,8 +4,20 @@ from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeServiceScanAction",
"NodeServiceStopAction",
"NodeServiceStartAction",
"NodeServicePauseAction",
"NodeServiceResumeAction",
"NodeServiceRestartAction",
"NodeServiceDisableAction",
"NodeServiceEnableAction",
"NodeServiceFixAction",
)
class NodeServiceAbstractAction(AbstractAction):
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
service_name: str
@@ -22,6 +34,8 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_
"""Action which scans a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceScanAction"""
verb: str = "scan"
@@ -29,6 +43,8 @@ class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_
"""Action which stops a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStopAction."""
verb: str = "stop"
@@ -36,6 +52,8 @@ class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service
"""Action which starts a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStartAction."""
verb: str = "start"
@@ -43,6 +61,8 @@ class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service
"""Action which pauses a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServicePauseAction."""
verb: str = "pause"
@@ -50,6 +70,8 @@ class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_servic
"""Action which resumes a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceResumeAction."""
verb: str = "resume"
@@ -57,6 +79,8 @@ class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_servi
"""Action which restarts a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceRestartAction."""
verb: str = "restart"
@@ -64,6 +88,8 @@ class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_servi
"""Action which disables a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceDisableAction."""
verb: str = "disable"
@@ -71,6 +97,8 @@ class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_servic
"""Action which enables a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceEnableAction."""
verb: str = "enable"
@@ -78,4 +106,6 @@ class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_f
"""Action which fixes a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceFixAction."""
verb: str = "fix"

View File

@@ -1,12 +1,13 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("NodeSessionsRemoteLoginAction", "NodeSessionsRemoteLogoutAction")
class NodeSessionAbstractAction(AbstractAction):
class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"):
"""Base class for NodeSession actions."""
class ConfigSchema(AbstractAction.ConfigSchema):
@@ -15,19 +16,20 @@ class NodeSessionAbstractAction(AbstractAction):
node_name: str
remote_ip: str
@abstractmethod
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Abstract method. Should return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
"""Abstract method. Should return the action formatted as a request which
can be ingested by the PrimAITE simulation."""
pass
class NodeSessionsRemoteLoginAction(AbstractAction, identifier="node_session_remote_login"):
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
"""Action which performs a remote session login."""
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLoginAction."""
username: str
password: str
@@ -49,11 +51,12 @@ class NodeSessionsRemoteLoginAction(AbstractAction, identifier="node_session_rem
]
class NodeSessionsRemoteLogoutAction(AbstractAction, identifier="node_session_remote_logout"):
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logout"):
"""Action which performs a remote session logout."""
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
pass
@classmethod

View File

@@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs
__all__ = [
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
"ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",]
# fmt: on