#2912 - Steps to get test_actions passing the refactored actions. Some linting changes and YAML updates.

This commit is contained in:
Charlie Crane
2024-10-30 18:34:05 +00:00
parent 5cd629a821
commit 844a3a60fa
19 changed files with 480 additions and 320 deletions

View File

@@ -1,5 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import ABC
from typing import Any, ClassVar, Dict, Type
@@ -7,6 +8,7 @@ from pydantic import BaseModel, ConfigDict
from primaite.interface.request import RequestFormat
class AbstractAction(BaseModel):
"""Base class for actions."""
@@ -34,15 +36,17 @@ class AbstractAction(BaseModel):
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
pass
@classmethod
def from_config(cls, config: Dict) -> "AbstractAction":
"""Create an action component from a config dictionary"""
type_id = config.get("type")
"""Create an action component from a config dictionary."""
# set attributes for action based off config dict
# if config["type"] not in cls._registry:
# raise ValueError(f"Invalid action reward type {config['type']}")
if type_id in cls._registry:
return cls(type=type_id, model_config=config)
else:
return []
for attribute, value in config.items():
if not hasattr(cls.ConfigSchema, attribute):
setattr(cls.ConfigSchema, attribute, value)
return cls

View File

@@ -14,6 +14,16 @@ __all__ = (
)
class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"):
"""Base class for ACL actions."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL abstract actions."""
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
@@ -26,9 +36,9 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
source_ip_id: int
source_wildcard_id: int
source_port_id: int
dest_ip_id: int
dest_wildcard_id: int
dest_port_id: int
dst_ip: str
dst_wildcard_id: int
dst_port: int
protocol_name: str
class ACLRuleOptions(BaseModel):
@@ -46,13 +56,13 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"""Rule source IP wildcard. By default, use the wildcard at index 0 from action manager."""
source_port_id: int = Field(default=1, ge=1)
"""Rule source port. By default, all source ports."""
dest_ip_id: int = Field(default=1, ge=1)
dst_ip_id: int = Field(default=1, ge=1)
"""Rule destination IP address. By default, all ip addresses."""
dest_wildcard_id: int = Field(default=0, ge=0)
dst_wildcard_id: int = Field(default=0, ge=0)
"""Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager."""
dest_port_id: int = Field(default=1, ge=1)
dst_port_id: int = Field(default=1, ge=1)
"""Rule destination port. By default, all destination ports."""
protocol_id: int = Field(default=1, ge=1)
protocol_name: str = "ALL"
"""Rule protocol. By default, all protocols."""
@field_validator(
@@ -62,7 +72,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"dest_ip_id",
"dest_port_id",
"dest_wildcard_id",
"protocol_id",
"protocol_name",
mode="before",
)
@classmethod
@@ -82,10 +92,10 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
permission=config.permission,
source_ip_id=config.source_ip_id,
source_wildcard_id=config.source_wildcard_id,
dest_ip_id=config.dest_ip_id,
dest_ip_id=config.dst_ip,
dest_wildcard_id=config.dest_wildcard_id,
source_port_id=config.source_port_id,
dest_port_id=config.dest_port_id,
dest_port_id=config.dst_port_id,
protocol=config.protocol_name,
)
if parsed_options.permission == 1:
@@ -95,10 +105,10 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
# else:
# _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if parsed_options.protocol_id == 1:
if parsed_options.protocol_name == "ALL":
protocol = "ALL"
else:
protocol = cls.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
protocol = parsed_options.protocol_name
# subtract 2 to account for UNUSED=0 and ALL=1.
if parsed_options.source_ip_id == 1:
@@ -120,7 +130,9 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
else:
dst_ip = cls.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
dst_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
dst_ip=config.dst_ip
dst_wildcard = config.dest_wildcard_id
if parsed_options.dest_port_id == 1:
dst_port = "ALL"
@@ -134,14 +146,14 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
config.target_router,
"acl",
"add_rule",
permission_str,
config.permission_str,
protocol,
str(src_ip),
src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
dst_port,
config.src_wildcard,
config.src_port,
str(config.dst_ip),
config.dst_wildcard,
config.dst_port,
config.position,
]
@@ -161,9 +173,27 @@ class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_ru
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
class FirewallACLAddRuleAction(AbstractAction, identifier="firewall_acl_add_rule"):
class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_rule"):
"""Action which adds a rule to a firewall port's ACL."""
max_acl_rules: int
num_ips: int
num_ports: int
num_protocols: int
num_permissions: int = 3
permission: str
class ConfigSchema(ACLAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLAddRuleAction."""
max_acl_rules: int
num_ips: int
num_ports: int
num_protocols: int
num_permissions: int = 3
permission: str
def __init__(
self,
manager: "ActionManager",
@@ -198,92 +228,85 @@ class FirewallACLAddRuleAction(AbstractAction, identifier="firewall_acl_add_rule
"protocol_id": num_protocols,
}
def form_request(
self,
target_firewall_nodename: str,
firewall_port_name: str,
firewall_port_direction: str,
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
) -> List[str]:
@classmethod
def form_request(cls, config:ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if permission == 0:
if config.permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
elif config.permission == 1:
permission_str = "PERMIT"
elif permission == 2:
elif config.permission == 2:
permission_str = "DENY"
# else:
# _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
if config.protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if protocol_id == 1:
if config.protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
pass
if source_ip_id == 0:
if config.source_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif source_ip_id == 1:
elif config.source_ip_id == 1:
src_ip = "ALL"
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_port_id == 0:
if config.source_port_id == 0:
return ["do_nothing"] # invalid formulation
elif source_port_id == 1:
elif config.source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
pass
if dest_ip_id == 0:
if config.dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
elif config.dest_ip_id == 1:
dst_ip = "ALL"
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
pass
if dest_port_id == 0:
if config.dest_port_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_port_id == 1:
elif config.dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
# src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
# dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
pass
return [
"network",
"node",
target_firewall_nodename,
firewall_port_name,
firewall_port_direction,
config.target_firewall_nodename,
config.firewall_port_name,
config.firewall_port_direction,
"acl",
"add_rule",
permission_str,
protocol,
str(src_ip),
src_wildcard,
config.src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
config.dst_wildcard,
dst_port,
position,
config.position,
]

View File

@@ -35,12 +35,21 @@ class NodeApplicationAbstractAction(AbstractAction, identifier="node_application
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.application_name is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "application", config.application_name, cls.verb]
return [
"network",
"node",
config.node_name,
"application",
config.application_name,
cls.model_fields["verb"].default,
]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
"""Action which executes an application."""
verb: str = "execute"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationExecuteAction."""
@@ -50,6 +59,8 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="no
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
"""Action which scans an application."""
verb: str = "scan"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationScanAction."""
@@ -59,6 +70,8 @@ class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
"""Action which closes an application."""
verb: str = "close"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationCloseAction."""
@@ -68,6 +81,8 @@ class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
"""Action which fixes an application."""
verb: str = "fix"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationFixAction."""
@@ -77,18 +92,50 @@ class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_a
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
"""Action which installs an application."""
verb: str = "install"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationInstallAction."""
verb: str = "install"
# TODO: Either changes to application form_request bits, or add that here.
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"software_manager",
"application",
cls.model_fields["verb"].default,
config.application_name,
]
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
"""Action which removes/uninstalls an application."""
verb: str = "uninstall"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationRemoveAction."""
verb: str = "uninstall"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"software_manager",
"application",
cls.model_fields["verb"].default,
config.application_name,
]

View File

@@ -18,7 +18,7 @@ __all__ = (
)
class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_ransomware"):
class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_ransomware_configure"):
"""Action which sets config parameters for a ransomware script on a node."""
class ConfigSchema(AbstractAction.ConfigSchema):
@@ -66,7 +66,7 @@ class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
return ["network", "node", node_name, "application", "DoSBot", "configure", config]
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2"):
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
"""Action which configures a C2 Beacon based on the parameters given."""
class ConfigSchema(AbstractAction.ConfigSchema):
@@ -105,14 +105,14 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2"):
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
config = ConfigureC2BeaconAction._Opts(
c2_server_ip_address=config["c2_server_ip_address"],
keep_alive_frequency=config["keep_alive_frequency"],
masquerade_port=config["masquerade_port"],
masquerade_protocol=config["masquerade_protocol"],
configuration = ConfigureC2BeaconAction._Opts(
c2_server_ip_address=config.c2_server_ip_address,
keep_alive_frequency=config.keep_alive_frequency,
masquerade_port=config.masquerade_port,
masquerade_protocol=config.masquerade_protocol,
)
ConfigureC2BeaconAction._Opts.model_validate(config) # check that options adhere to schema
ConfigureC2BeaconAction._Opts.model_validate(configuration) # check that options adhere to schema
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config.__dict__]
@@ -142,7 +142,7 @@ class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_c
]
class TerminalC2ServerAction(AbstractAction, identifier="terminal_c2_server"):
class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
class _Opts(BaseModel):
@@ -173,7 +173,7 @@ class TerminalC2ServerAction(AbstractAction, identifier="terminal_c2_server"):
return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model]
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="ransomware_launch"):
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
class ConfigSchema(AbstractAction.ConfigSchema):
@@ -190,7 +190,7 @@ class RansomwareLaunchC2ServerAction(AbstractAction, identifier="ransomware_laun
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
class ExfiltrationC2ServerAction(AbstractAction, identifier="exfiltration_c2_server"):
class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"):
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
class _Opts(BaseModel):

View File

@@ -44,22 +44,45 @@ class NodeFileAbstractAction(AbstractAction, identifier="node_file_abstract_acti
config.folder_name,
"file",
config.file_name,
cls.verb,
cls.model_fields["verb"].default,
]
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
"""Action which creates a new file in a given folder."""
verb: str = "create"
force: bool = False
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCreateAction."""
verb: str = "create"
force: bool = False
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
cls.model_fields["verb"].default,
"file",
config.folder_name,
config.file_name,
cls.model_fields["force"].default,
]
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
"""Action which scans a file."""
verb: str = "scan"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileScanAction."""
@@ -69,15 +92,35 @@ class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
"""Action which deletes a file."""
verb: str = "delete"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileDeleteAction."""
verb: str = "delete"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
cls.model_fields["verb"].default,
"file",
config.folder_name,
config.file_name,
]
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
"""Action which restores a file."""
verb: str = "restore"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileRestoreAction."""
@@ -87,6 +130,8 @@ class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restor
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
"""Action which corrupts a file."""
verb: str = "corrupt"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCorruptAction."""
@@ -96,7 +141,46 @@ class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrup
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
"""Action which increases a file's access count."""
verb: str = "access"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileAccessAction."""
verb: str = "access"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
cls.model_fields["verb"].default,
config.folder_name,
config.file_name,
]
class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"):
"""Action which checks the hash of a file."""
verb: str = "checkhash"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCheckhashAction."""
verb: str = "checkhash"
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
"""Action which repairs a file"""
verb: str = "repair"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration Schema for NodeFileRepairAction."""
verb: str = "repair"

View File

@@ -34,12 +34,22 @@ class NodeFolderAbstractAction(AbstractAction, identifier="node_folder_abstract"
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "file_system", "folder", config.folder_name, cls.verb]
return [
"network",
"node",
config.node_name,
"file_system",
"folder",
config.folder_name,
cls.model_fields["verb"].default,
]
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
"""Action which scans a folder."""
verb: str = "scan"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderScanAction."""
@@ -49,6 +59,8 @@ class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_sca
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
"""Action which checks the hash of a folder."""
verb: str = "checkhash"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderCheckhashAction."""
@@ -58,6 +70,8 @@ class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folde
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
"""Action which repairs a folder."""
verb: str = "repair"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderRepairAction."""
@@ -67,16 +81,35 @@ class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_r
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
"""Action which restores a folder."""
verb: str = "restore"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderRestoreAction."""
verb: str = "restore"
class NodeFolderCreateAction(AbstractAction, identifier="node_folder_create"):
class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"):
"""Action which creates a new folder."""
verb: str = "create"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderCreateAction."""
verb: str = "create"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
cls.model_fields["verb"].default,
"folder",
config.folder_name,
]

View File

@@ -13,11 +13,12 @@ class HostNICAbstractAction(AbstractAction, identifier="host_nic_abstract"):
class.
"""
node_name: str
nic_num: str
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for HostNIC actions."""
num_nodes: str
max_nics_per_node: str
node_name: str
nic_num: str
@@ -26,12 +27,21 @@ class HostNICAbstractAction(AbstractAction, identifier="host_nic_abstract"):
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.nic_num is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "network_interface", config.nic_num, cls.verb]
return [
"network",
"node",
config.node_name,
"network_interface",
config.nic_num,
cls.model_fields["verb"].default,
]
class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
"""Action which enables a NIC."""
verb: str = "enable"
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
"""Configuration schema for HostNICEnableAction."""
@@ -41,6 +51,8 @@ class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"):
"""Action which disables a NIC."""
verb: str = "disable"
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
"""Configuration schema for HostNICDisableAction."""

View File

@@ -33,7 +33,8 @@ class DoNothingAction(AbstractAction, identifier="do_nothing"):
type: Literal["do_nothing"] = "do_nothing"
def form_request(self, options: ConfigSchema) -> RequestFormat:
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
@@ -44,17 +45,6 @@ class ActionManager:
def __init__(
self,
actions: List[Dict], # stores list of actions available to agent
# nodes: List[Dict], # extra configuration for each node
# max_folders_per_node: int = 2, # allows calculating shape
# max_files_per_folder: int = 2, # allows calculating shape
# max_services_per_node: int = 2, # allows calculating shape
# max_applications_per_node: int = 2, # allows calculating shape
# max_nics_per_node: int = 8, # allows calculating shape
# max_acl_rules: int = 10, # allows calculating shape
# protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
# ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
# ip_list: List[str] = [], # to allow us to map an index to an ip address.
# wildcard_list: List[str] = [], # to allow mapping from wildcard index to
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
*args,
**kwargs,
@@ -66,114 +56,12 @@ class ActionManager:
:param actions: List of action specs which should be made available to the agent. The keys of each spec are:
'type' and 'options' for passing any options to the action class's init method
:type actions: List[dict]
:param nodes: Extra configuration for each node.
:type nodes: List[Dict]
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
:type max_folders_per_node: int
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
:type max_files_per_folder: int
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
:type max_services_per_node: int
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
:type max_nics_per_node: int
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
:type max_acl_rules: int
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
:type protocols: List[str]
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
:type ports: List[str]
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_list: Optional[List[str]]
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
# self.node_names: List[str] = [n["node_name"] for n in nodes]
"""List of node names in this action space. The list order is the mapping between node index and node name."""
# self.application_names: List[List[str]] = []
"""
List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name.
The first index corresponds to node id, the second index is the app id on that particular node.
For instance, self.application_names[0][2] is the name of the third application on the first node.
"""
# self.service_names: List[List[str]] = []
"""
List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name.
The first index corresponds to node id, the second index is the service id on that particular node.
For instance, self.service_names[0][2] is the name of the third service on the first node.
"""
# self.folder_names: List[List[str]] = []
"""
List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder
name. The first index corresponds to node id, the second index is the folder id on that particular node.
For instance, self.folder_names[0][2] is the name of the third folder on the first node.
"""
# self.file_names: List[List[List[str]]] = []
"""
List of files per folder per node. The list order gives the three-index mapping between
(node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the
folder id on that particular node, and the third index is the file id in that particular folder.
For instance, self.file_names[0][2][1] is the name of the second file in the third folder on the first node.
"""
# Populate lists of apps, services, files, folders, etc on nodes.
# for node in nodes:
# app_list = [a["application_name"] for a in node.get("applications", [])]
# while len(app_list) < max_applications_per_node:
# app_list.append(None)
# self.application_names.append(app_list)
# svc_list = [s["service_name"] for s in node.get("services", [])]
# while len(svc_list) < max_services_per_node:
# svc_list.append(None)
# self.service_names.append(svc_list)
# folder_list = [f["folder_name"] for f in node.get("folders", [])]
# while len(folder_list) < max_folders_per_node:
# folder_list.append(None)
# self.folder_names.append(folder_list)
# file_sublist = []
# for folder in node.get("folders", [{"files": []}]):
# file_list = [f["file_name"] for f in folder.get("files", [])]
# while len(file_list) < max_files_per_folder:
# file_list.append(None)
# file_sublist.append(file_list)
# while len(file_sublist) < max_folders_per_node:
# file_sublist.append([None] * max_files_per_folder)
# self.file_names.append(file_sublist)
# self.protocols: List[str] = protocols
# self.ports: List[str] = ports
# self.ip_address_list: List[str] = ip_list
# self.wildcard_list: List[str] = wildcard_list
# if self.wildcard_list == []:
# self.wildcard_list = ["NONE"]
# # action_args are settings which are applied to the action space as a whole.
# global_action_args = {
# "num_nodes": len(self.node_names),
# "num_folders": max_folders_per_node,
# "num_files": max_files_per_folder,
# "num_services": max_services_per_node,
# "num_applications": max_applications_per_node,
# "num_nics": max_nics_per_node,
# "num_acl_rules": max_acl_rules,
# "num_protocols": len(self.protocols),
# "num_ports": len(self.protocols),
# "num_ips": len(self.ip_address_list),
# "max_acl_rules": max_acl_rules,
# "max_nics_per_node": max_nics_per_node,
# }
self.actions: Dict[str, AbstractAction] = {}
for act_spec in actions:
# each action is provided into the action space config like this:
# - type: ACTION_TYPE
# options:
# option_1: value1
# option_2: value2
# where `type` decides which AbstractAction subclass should be used
# and `options` is an optional dict of options to pass to the init method of the action class
act_type = act_spec.get("type")
# act_options = act_spec.get("options", {}) # Don't need this anymore I think?
self.actions[act_type] = AbstractAction._registry[act_type]
self.action_map: Dict[int, Tuple[str, Dict]] = {}
@@ -237,8 +125,8 @@ class ActionManager:
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_obj = self.actions[action_identifier]
return act_obj.form_request(action_options)
act_obj = self.actions[action_identifier].from_config(config=action_options)
return act_obj.form_request(config=act_obj.ConfigSchema)
@property
def space(self) -> spaces.Space:

View File

@@ -24,12 +24,21 @@ class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstrac
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.target_nodename is None or config.port_id is None:
return ["do_nothing"]
return ["network", "node", config.target_nodename, "network_interface", config.port_id, cls.verb]
return [
"network",
"node",
config.target_nodename,
"network_interface",
config.port_id,
cls.model_fields["verb"].default,
]
class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"):
"""Action which enables are port on a router or a firewall."""
verb: str = "enable"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortEnableAction."""
@@ -39,6 +48,8 @@ class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_por
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
"""Action which disables are port on a router or a firewall."""
verb: str = "disable"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortDisableAction."""

View File

@@ -101,7 +101,8 @@ class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
"""Action which performs an NMAP port scan."""
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
source_node: str
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
show: Optional[bool] = (False,)

View File

@@ -32,12 +32,14 @@ class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstrac
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.node_name, "service", config.service_name, cls.verb]
return ["network", "node", config.node_name, "service", config.service_name, cls.model_fields["verb"].default]
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
"""Action which scans a service."""
verb: str = "scan"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceScanAction."""
@@ -47,6 +49,8 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
"""Action which stops a service."""
verb: str = "stop"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStopAction."""
@@ -56,6 +60,8 @@ class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
"""Action which starts a service."""
verb: str = "start"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStartAction."""
@@ -65,6 +71,8 @@ class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
"""Action which pauses a service."""
verb: str = "pause"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServicePauseAction."""
@@ -74,6 +82,8 @@ class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
"""Action which resumes a service."""
verb: str = "resume"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceResumeAction."""
@@ -83,6 +93,8 @@ class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_servic
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
"""Action which restarts a service."""
verb: str = "restart"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceRestartAction."""
@@ -92,6 +104,8 @@ class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_servi
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
"""Action which disables a service."""
verb: str = "disable"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceDisableAction."""
@@ -101,6 +115,8 @@ class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_servi
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
"""Action which enables a service."""
verb: str = "enable"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceEnableAction."""
@@ -110,6 +126,8 @@ class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_servic
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
"""Action which fixes a service."""
verb: str = "fix"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceFixAction."""

View File

@@ -30,6 +30,9 @@ class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstrac
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
"""Action which performs a remote session login."""
username: str
password: str
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLoginAction."""
@@ -54,7 +57,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_
]
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logout"):
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"):
"""Action which performs a remote session logout."""
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
@@ -68,3 +71,33 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "service", "Terminal", "remote_logoff", config.remote_ip]
class NodeAccountsChangePasswordAction(NodeSessionAbstractAction, identifier="node_accounts_change_password"):
"""Action which changes the password for a user."""
username: str
current_password: str
new_password: str
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeAccountsChangePasswordAction."""
username: str
current_password: str
new_password: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"service",
"UserManager",
"change_password",
config.username,
config.current_password,
cls.new_password,
]

View File

@@ -96,8 +96,8 @@ agents:
action_space:
action_list:
- type: DONOTHING
- type: FIREWALL_ACL_ADDRULE
- type: do_nothing
- type: FIREWALL_ACL_ADDRULE firewall_acl_add_rule
- type: FIREWALL_ACL_REMOVERULE
- type: NETWORK_PORT_DISABLE
- type: NETWORK_PORT_ENABLE

View File

@@ -34,15 +34,16 @@ agents:
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_NETWORK_SERVICE_RECON
- type: node_network_service_recon
action_map:
0:
action: NODE_NMAP_NETWORK_SERVICE_RECON
action: node_network_service_recon
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
target_port: 80
target_protocol: tcp
show: false
reward_function:
reward_components:

View File

@@ -34,13 +34,14 @@ agents:
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_PING_SCAN
- type: node_nmap_ping_scan
action_map:
0:
action: NODE_NMAP_PING_SCAN
action: node_nmap_ping_scan
options:
source_node: client_1
node_name: client_1
target_ip_address: 192.168.1.0/24
show: False
reward_function:
reward_components:

View File

@@ -34,19 +34,21 @@ agents:
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_PORT_SCAN
- type: node_nmap_port_scan
action_map:
0:
action: NODE_NMAP_PORT_SCAN
action: node_nmap_port_scan
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
target_protocol: tcp
target_port:
- 21
- 53
- 80
- 123
- 219
show: false
reward_function:
reward_components:

View File

@@ -419,54 +419,54 @@ def game_and_agent():
install_stuff_to_sim(sim)
actions = [
{"type": "DONOTHING"},
{"type": "NODE_SERVICE_SCAN"},
{"type": "NODE_SERVICE_STOP"},
{"type": "NODE_SERVICE_START"},
{"type": "NODE_SERVICE_PAUSE"},
{"type": "NODE_SERVICE_RESUME"},
{"type": "NODE_SERVICE_RESTART"},
{"type": "NODE_SERVICE_DISABLE"},
{"type": "NODE_SERVICE_ENABLE"},
{"type": "NODE_SERVICE_FIX"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_APPLICATION_SCAN"},
{"type": "NODE_APPLICATION_CLOSE"},
{"type": "NODE_APPLICATION_FIX"},
{"type": "NODE_APPLICATION_INSTALL"},
{"type": "NODE_APPLICATION_REMOVE"},
{"type": "NODE_FILE_CREATE"},
{"type": "NODE_FILE_SCAN"},
{"type": "NODE_FILE_CHECKHASH"},
{"type": "NODE_FILE_DELETE"},
{"type": "NODE_FILE_REPAIR"},
{"type": "NODE_FILE_RESTORE"},
{"type": "NODE_FILE_CORRUPT"},
{"type": "NODE_FILE_ACCESS"},
{"type": "NODE_FOLDER_CREATE"},
{"type": "NODE_FOLDER_SCAN"},
{"type": "NODE_FOLDER_CHECKHASH"},
{"type": "NODE_FOLDER_REPAIR"},
{"type": "NODE_FOLDER_RESTORE"},
{"type": "NODE_OS_SCAN"},
{"type": "NODE_SHUTDOWN"},
{"type": "NODE_STARTUP"},
{"type": "NODE_RESET"},
{"type": "ROUTER_ACL_ADDRULE"},
{"type": "ROUTER_ACL_REMOVERULE"},
{"type": "HOST_NIC_ENABLE"},
{"type": "HOST_NIC_DISABLE"},
{"type": "NETWORK_PORT_ENABLE"},
{"type": "NETWORK_PORT_DISABLE"},
{"type": "CONFIGURE_C2_BEACON"},
{"type": "C2_SERVER_RANSOMWARE_LAUNCH"},
{"type": "C2_SERVER_RANSOMWARE_CONFIGURE"},
{"type": "C2_SERVER_TERMINAL_COMMAND"},
{"type": "C2_SERVER_DATA_EXFILTRATE"},
{"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"},
{"type": "SSH_TO_REMOTE"},
{"type": "SESSIONS_REMOTE_LOGOFF"},
{"type": "NODE_SEND_REMOTE_COMMAND"},
{"type": "do_nothing"},
{"type": "node_service_scan"},
{"type": "node_service_stop"},
{"type": "node_service_start"},
{"type": "node_service_pause"},
{"type": "node_service_resume"},
{"type": "node_service_restart"},
{"type": "node_service_disable"},
{"type": "node_service_enable"},
{"type": "node_service_fix"},
{"type": "node_application_execute"},
{"type": "node_application_scan"},
{"type": "node_application_close"},
{"type": "node_application_fix"},
{"type": "node_application_install"},
{"type": "node_application_remove"},
{"type": "node_file_create"},
{"type": "node_file_scan"},
{"type": "node_file_checkhash"},
{"type": "node_file_delete"},
{"type": "node_file_repair"},
{"type": "node_file_restore"},
{"type": "node_file_corrupt"},
{"type": "node_file_access"},
{"type": "node_folder_create"},
{"type": "node_folder_scan"},
{"type": "node_folder_checkhash"},
{"type": "node_folder_repair"},
{"type": "node_folder_restore"},
{"type": "node_os_scan"},
{"type": "node_shutdown"},
{"type": "node_startup"},
{"type": "node_reset"},
{"type": "router_acl_add_rule"},
{"type": "router_acl_remove_rule"},
{"type": "host_nic_enable"},
{"type": "host_nic_disable"},
{"type": "network_port_enable"},
{"type": "network_port_disable"},
{"type": "configure_c2_beacon"},
{"type": "c2_server_ransomware_launch"},
{"type": "c2_server_ransomware_configure"},
{"type": "c2_server_terminal_command"},
{"type": "c2_server_data_exfiltrate"},
{"type": "node_accounts_change_password"},
{"type": "node_session_remote_login"},
{"type": "node_session_remote_logoff"},
{"type": "node_send_remote_command"},
]
action_space = ActionManager(

View File

@@ -35,7 +35,7 @@ def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent])
"""Test that the DoNothingAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
@@ -56,7 +56,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
assert svc.health_state_visible == SoftwareHealthState.UNUSED
# 2: Scan and check that the visible state is now correct
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.GOOD
@@ -67,7 +67,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
assert svc.health_state_visible == SoftwareHealthState.GOOD
# 4: Scan and check that the visible state is now correct
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.COMPROMISED
@@ -88,7 +88,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
svc.health_state_actual = SoftwareHealthState.COMPROMISED
# 2: Apply a patch action
action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0})
action = ("node_service_fix", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
@@ -96,7 +96,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
assert svc.health_state_actual == SoftwareHealthState.FIXING
# 4: perform a few do-nothing steps and check that the service is now in the good state
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.GOOD
@@ -121,7 +121,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
# 2: Add a rule to block client 1 from reaching server 2 on router
action = (
"ROUTER_ACL_ADDRULE",
"router_acl_add_rule",
{
"target_router": "router",
"position": 4, # 4th rule
@@ -130,7 +130,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
"dest_ip_id": 6, # 10.0.2.3 (server_2)
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"protocol_name": "ALL", # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
},
@@ -186,7 +186,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
# 2: Remove rule that allows HTTP traffic across the network
action = (
"ROUTER_ACL_REMOVERULE",
"router_acl_remove_rule",
{
"target_router": "router",
"position": 3, # 4th rule
@@ -219,10 +219,10 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
# 2: Disable the NIC on client_1
action = (
"HOST_NIC_DISABLE",
"host_nic_disable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -250,10 +250,10 @@ def test_host_nic_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAg
# 2: Use action to enable nic
action = (
"HOST_NIC_ENABLE",
"host_nic_enable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -277,11 +277,11 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge
# 2: perform a scan and make sure nothing has changed
action = (
"NODE_FILE_SCAN",
"node_file_scan",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads,
"file_id": 0, # cat.png
"node_name": "client_1", # client_1,
"folder_name": "downloads", # downloads,
"file_name": "cat.png", # cat.png
},
)
agent.store_action(action)
@@ -314,11 +314,11 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
# 2: delete the file
action = (
"NODE_FILE_DELETE",
"node_file_delete",
{
"node_id": 0, # client_1
"folder_id": 0, # downloads
"file_id": 0, # cat.png
"node_name": "client_1", # client_1
"folder_name": "downloads", # downloads
"file_name": "cat.png", # cat.png
},
)
agent.store_action(action)
@@ -334,15 +334,11 @@ def test_node_file_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a file is created."""
game, agent = game_and_agent
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
client_1 = game.simulation.network.get_node_by_hostname("client_1")
action = (
"NODE_FILE_CREATE",
{
"node_id": 0,
"folder_name": "test",
"file_name": "file.txt",
},
"node_file_create",
{"node_name": "client_1", "folder_name": "test", "file_name": "file.txt", "force": "False"},
)
agent.store_action(action)
game.step()
@@ -357,9 +353,9 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
action = (
"NODE_FILE_CREATE",
"node_file_create",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
},
@@ -370,9 +366,9 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
assert client_1.file_system.get_file(folder_name="test", file_name="file.txt").num_access == 0
action = (
"NODE_FILE_ACCESS",
"node_file_access",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
},
@@ -390,9 +386,9 @@ def test_node_folder_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
action = (
"NODE_FOLDER_CREATE",
"node_folder_create",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
},
)
@@ -418,7 +414,7 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
# 2: Disable the NIC on client_1
action = (
"NETWORK_PORT_DISABLE",
"network_port_disable",
{
"target_nodename": "router", # router
"port_id": 1, # port 1
@@ -450,7 +446,7 @@ def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGa
# 2: Use action to enable port
action = (
"NETWORK_PORT_ENABLE",
"network_port_enable",
{
"target_nodename": "router", # router
"port_id": 1, # port 1
@@ -480,7 +476,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
assert browser.health_state_visible == SoftwareHealthState.UNUSED
# 2: Scan and check that the visible state is now correct
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = ("node_application_scan", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.GOOD
@@ -491,7 +487,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
assert browser.health_state_visible == SoftwareHealthState.GOOD
# 4: Scan and check that the visible state is now correct
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = ("node_application_scan", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.COMPROMISED
@@ -512,7 +508,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
browser.health_state_actual = SoftwareHealthState.COMPROMISED
# 2: Apply a fix action
action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0})
action = ("node_application_fix", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
@@ -520,7 +516,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
assert browser.health_state_actual == SoftwareHealthState.FIXING
# 4: perform a few do-nothing steps and check that the application is now in the good state
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.GOOD
@@ -538,7 +534,7 @@ def test_node_application_close_integration(game_and_agent: Tuple[PrimaiteGame,
assert browser.operating_state == ApplicationOperatingState.RUNNING
# 2: Apply a close action
action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0})
action = ("node_application_close", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
@@ -549,7 +545,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
"""Test that the NodeApplicationInstallAction and NodeApplicationRemoveAction can form a request and that
it is accepted by the simulation.
When you initiate a install action, the Application will be installed and configured on the node.
When you initiate an install action, the Application will be installed and configured on the node.
The remove action will uninstall the application from the node."""
game, agent = game_and_agent
@@ -557,13 +553,13 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
assert client_1.software_manager.software.get("DoSBot") is None
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
action = ("node_application_install", {"node_name": "client_1", "application_name": "DoSBot"})
agent.store_action(action)
game.step()
assert client_1.software_manager.software.get("DoSBot") is not None
action = ("NODE_APPLICATION_REMOVE", {"node_id": 0, "application_name": "DoSBot"})
action = ("node_application_remove", {"node_name": "client_1", "application_name": "DoSBot"})
agent.store_action(action)
game.step()

View File

@@ -27,9 +27,9 @@ def test_probabilistic_agent():
action_space = ActionManager(
actions=[
{"type": "DONOTHING"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_FILE_DELETE"},
{"type": "do_nothing"},
{"type": "node_application_execute"},
{"type": "node_file_delete"},
],
nodes=[
{
@@ -47,9 +47,15 @@ def test_probabilistic_agent():
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
act_map={
0: {"action": "DONOTHING", "options": {}},
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
0: {"action": "do_nothing", "options": {}},
1: {
"action": "node_application_execute",
"options": {"node_name": "client_1", "application_name": "WebBrowser"},
},
2: {
"action": "node_file_delete",
"options": {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
},
},
)
observation_space = ObservationManager(NestedObservation(components={}))
@@ -70,11 +76,11 @@ def test_probabilistic_agent():
node_file_delete_count = 0
for _ in range(N_TRIALS):
a = pa.get_action(0)
if a == ("DONOTHING", {}):
if a == ("do_nothing", {}):
do_nothing_count += 1
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
elif a == ("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}):
node_application_execute_count += 1
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
elif a == ("node_file_delete", {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}):
node_file_delete_count += 1
else:
raise AssertionError("Probabilistic agent produced an unexpected action.")