#2912 - Mid-day commit. Actions moving across from actions.py to game.agent.actions

This commit is contained in:
Charlie Crane
2024-10-17 12:22:30 +01:00
parent 861cfe2c0a
commit cd30e2d084
9 changed files with 743 additions and 244 deletions

View File

@@ -10,13 +10,12 @@ AbstractAction. The ActionManager is responsible for:
ensures that requests conform to the simulator's request format.
"""
from abc import abstractmethod
from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
from primaite import getLogger
from primaite.game.agent.actions.manager import ActionManager
from primaite.game.agent.actions.manager import AbstractAction
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
from primaite.game.agent.actions.service import NodeServiceAbstractAction
from primaite.interface.request import RequestFormat
@@ -1238,4 +1237,3 @@ class RansomwareLaunchC2ServerAction(AbstractAction):
return ["do_nothing"]
# This action currently doesn't require any further configuration options.
return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"]

View File

@@ -0,0 +1,27 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.actions.manager import ActionManager
from primaite.game.agent.actions.service import (
NodeServiceDisableAction,
NodeServiceEnableAction,
NodeServiceFixAction,
NodeServicePauseAction,
NodeServiceRestartAction,
NodeServiceResumeAction,
NodeServiceScanAction,
NodeServiceStartAction,
NodeServiceStopAction,
)
__all__ = (
"NodeServiceDisableAction",
"NodeServiceEnableAction",
"NodeServiceFixAction",
"NodeServicePauseAction",
"NodeServiceRestartAction",
"NodeServiceResumeAction",
"NodeServiceScanAction",
"NodeServiceStartAction",
"NodeServiceStopAction",
"ActionManager",
)

View File

@@ -0,0 +1,170 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, List, Literal
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction
from primaite.game.game import _LOGGER
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
class ACLRuleOptions(BaseModel):
"""Validator for ACL_ADD_RULE options."""
target_router: str
"""On which router to add the rule, must be specified."""
position: int
"""At what position to add the rule, must be specified."""
permission: Literal[1, 2]
"""Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY."""
source_ip_id: int = Field(default=1, ge=1)
"""Rule source IP address. By default, all ip addresses."""
source_wildcard_id: int = Field(default=0, ge=0)
"""Rule source IP wildcard. By default, use the wildcard at index 0 from action manager."""
source_port_id: int = Field(default=1, ge=1)
"""Rule source port. By default, all source ports."""
dest_ip_id: int = Field(default=1, ge=1)
"""Rule destination IP address. By default, all ip addresses."""
dest_wildcard_id: int = Field(default=0, ge=0)
"""Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager."""
dest_port_id: int = Field(default=1, ge=1)
"""Rule destination port. By default, all destination ports."""
protocol_id: int = Field(default=1, ge=1)
"""Rule protocol. By default, all protocols."""
@field_validator(
"source_ip_id",
"source_port_id",
"source_wildcard_id",
"dest_ip_id",
"dest_port_id",
"dest_wildcard_id",
"protocol_id",
mode="before",
)
@classmethod
def not_none(cls, v: str, info: ValidationInfo) -> int:
"""If None is passed, use the default value instead."""
if v is None:
return cls.model_fields[info.field_name].default
return v
def __init__(
self,
manager: "ActionManager",
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
"""Init method for RouterACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
:type num_ips: int
:param num_ports: Number of ports in the simulation.
:type num_ports: int
:param num_protocols: Number of protocols in the simulation.
:type num_protocols: int
"""
super().__init__(manager=manager)
num_permissions = 3
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_id": num_ips,
"dest_ip_id": num_ips,
"source_port_id": num_ports,
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
def form_request(
self,
target_router: str,
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
# Validate incoming data.
parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
target_router=target_router,
position=position,
permission=permission,
source_ip_id=source_ip_id,
source_wildcard_id=source_wildcard_id,
dest_ip_id=dest_ip_id,
dest_wildcard_id=dest_wildcard_id,
source_port_id=source_port_id,
dest_port_id=dest_port_id,
protocol_id=protocol_id,
)
if parsed_options.permission == 1:
permission_str = "PERMIT"
elif parsed_options.permission == 2:
permission_str = "DENY"
else:
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if parsed_options.protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if parsed_options.source_ip_id == 1:
src_ip = "ALL"
else:
src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id)
if parsed_options.source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if parsed_options.dest_ip_id == 1:
dst_ip = "ALL"
else:
dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
if parsed_options.dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
"network",
"node",
target_router,
"acl",
"add_rule",
permission_str,
protocol,
str(src_ip),
src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
dst_port,
position,
]

View File

@@ -0,0 +1,64 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, Dict
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
class NodeApplicationAbstractAction(AbstractAction):
"""
Base class for application actions.
Any action which applies to an application and uses node_id and application_id as its only two parameters can
inherit from this base class.
"""
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
application_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.application_name is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "application", config.application_name, cls.verb]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
"""Action which executes an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
verb: str = "execute"
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
"""Action which scans an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
verb: str = "scan"
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
"""Action which closes an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
verb: str = "close"
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
"""Action which fixes an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
verb: str = "fix"
class NodeApplicationInstallAction(AbstractAction):
"""Action which installs an application."""
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
verb: str = "install"

View File

@@ -0,0 +1,79 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
class NodeFileAbstractAction(AbstractAction):
"""Abstract base class for file actions.
Any action which applies to a file and uses node_name, folder_name, and file_name as its only three parameters can inherit
from this base class.
"""
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
folder_name: str
file_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
"folder",
config.folder_name,
"file",
config.file_name,
cls.verb,
]
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
"""Action which creates a new file in a given folder."""
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
verb: str = "create"
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
"""Action which scans a file."""
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
verb: str = "scan"
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
"""Action which deletes a file."""
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
verb: str = "delete"
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
"""Action which restores a file."""
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
verb: str = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
"""Action which corrupts a file."""
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
verb: str = "corrupt"
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
"""Action which increases a file's access count."""
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
verb: str = "access"

View File

@@ -0,0 +1,65 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, Dict
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
class NodeFolderAbstractAction(AbstractAction):
"""
Base class for folder actions.
Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
this base class.
"""
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
folder_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, node_id: int, folder_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = cls.manager.get_node_name_by_idx(node_id)
folder_name = cls.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
if node_name is None or folder_name is None:
return ["do_nothing"]
return ["network", "node", node_name, "file_system", "folder", folder_name, cls.verb]
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
"""Action which scans a folder."""
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
verb: str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
"""Action which checks the hash of a folder."""
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
verb: str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
"""Action which repairs a folder."""
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
verb: str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
"""Action which restores a folder."""
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
verb: str = "restore"
class NodeFolderCreateAction(AbstractAction, identifier="node_folder_create"):
"""Action which creates a new folder."""
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
verb: str = "create"

View File

@@ -1,3 +1,4 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""yaml example
agents:
@@ -10,20 +11,22 @@ agents:
action_map:
"""
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestFormat
from __future__ import annotations
from gymnasium import spaces
import itertools
from typing import Any, ClassVar, Dict, List, Literal, Tuple, Type
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict
from primaite.game.game import _LOGGER, PrimaiteGame
from primaite.interface.request import RequestFormat
class AbstractAction(BaseModel):
"""Base class for actions."""
# notes:
# we actually don't need to hold any state in actions, so there's no need to define any __init__ logic.
# all the init methods in the old actions are just used for holding a verb and shape, which are not really used.
@@ -48,6 +51,7 @@ class AbstractAction(BaseModel):
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
class DoNothingAction(AbstractAction):
class ConfigSchema(AbstractAction.ConfigSchema):
type: Literal["do_nothing"] = "do_nothing"
@@ -55,6 +59,7 @@ class DoNothingAction(AbstractAction):
def form_request(self, options: ConfigSchema) -> RequestFormat:
return ["do_nothing"]
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -131,53 +136,53 @@ class ActionManager:
"""
# Populate lists of apps, services, files, folders, etc on nodes.
for node in nodes:
app_list = [a["application_name"] for a in node.get("applications", [])]
while len(app_list) < max_applications_per_node:
app_list.append(None)
self.application_names.append(app_list)
# for node in nodes:
# app_list = [a["application_name"] for a in node.get("applications", [])]
# while len(app_list) < max_applications_per_node:
# app_list.append(None)
# self.application_names.append(app_list)
svc_list = [s["service_name"] for s in node.get("services", [])]
while len(svc_list) < max_services_per_node:
svc_list.append(None)
self.service_names.append(svc_list)
# svc_list = [s["service_name"] for s in node.get("services", [])]
# while len(svc_list) < max_services_per_node:
# svc_list.append(None)
# self.service_names.append(svc_list)
folder_list = [f["folder_name"] for f in node.get("folders", [])]
while len(folder_list) < max_folders_per_node:
folder_list.append(None)
self.folder_names.append(folder_list)
# folder_list = [f["folder_name"] for f in node.get("folders", [])]
# while len(folder_list) < max_folders_per_node:
# folder_list.append(None)
# self.folder_names.append(folder_list)
file_sublist = []
for folder in node.get("folders", [{"files": []}]):
file_list = [f["file_name"] for f in folder.get("files", [])]
while len(file_list) < max_files_per_folder:
file_list.append(None)
file_sublist.append(file_list)
while len(file_sublist) < max_folders_per_node:
file_sublist.append([None] * max_files_per_folder)
self.file_names.append(file_sublist)
self.protocols: List[str] = protocols
self.ports: List[str] = ports
# file_sublist = []
# for folder in node.get("folders", [{"files": []}]):
# file_list = [f["file_name"] for f in folder.get("files", [])]
# while len(file_list) < max_files_per_folder:
# file_list.append(None)
# file_sublist.append(file_list)
# while len(file_sublist) < max_folders_per_node:
# file_sublist.append([None] * max_files_per_folder)
# self.file_names.append(file_sublist)
# self.protocols: List[str] = protocols
# self.ports: List[str] = ports
self.ip_address_list: List[str] = ip_list
self.wildcard_list: List[str] = wildcard_list
if self.wildcard_list == []:
self.wildcard_list = ["NONE"]
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
"num_nodes": len(self.node_names),
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_applications": max_applications_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols),
"num_ports": len(self.protocols),
"num_ips": len(self.ip_address_list),
"max_acl_rules": max_acl_rules,
"max_nics_per_node": max_nics_per_node,
}
# self.ip_address_list: List[str] = ip_list
# self.wildcard_list: List[str] = wildcard_list
# if self.wildcard_list == []:
# self.wildcard_list = ["NONE"]
# # action_args are settings which are applied to the action space as a whole.
# global_action_args = {
# "num_nodes": len(self.node_names),
# "num_folders": max_folders_per_node,
# "num_files": max_files_per_folder,
# "num_services": max_services_per_node,
# "num_applications": max_applications_per_node,
# "num_nics": max_nics_per_node,
# "num_acl_rules": max_acl_rules,
# "num_protocols": len(self.protocols),
# "num_ports": len(self.protocols),
# "num_ips": len(self.ip_address_list),
# "max_acl_rules": max_acl_rules,
# "max_nics_per_node": max_nics_per_node,
# }
self.actions: Dict[str, AbstractAction] = {}
for act_spec in actions:
# each action is provided into the action space config like this:
@@ -260,191 +265,191 @@ class ActionManager:
"""Return the gymnasium action space for this agent."""
return spaces.Discrete(len(self.action_map))
def get_node_name_by_idx(self, node_idx: int) -> str:
"""
Get the node name corresponding to the given index.
# def get_node_name_by_idx(self, node_idx: int) -> str:
# """
# Get the node name corresponding to the given index.
:param node_idx: The index of the node to retrieve.
:type node_idx: int
:return: The node hostname.
:rtype: str
"""
if not node_idx < len(self.node_names):
msg = (
f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
f"has {len(self.node_names)} nodes."
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.node_names[node_idx]
# :param node_idx: The index of the node to retrieve.
# :type node_idx: int
# :return: The node hostname.
# :rtype: str
# """
# if not node_idx < len(self.node_names):
# msg = (
# f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
# f"has {len(self.node_names)} nodes."
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.node_names[node_idx]
def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
"""
Get the folder name corresponding to the given node and folder indices.
# def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
# """
# Get the folder name corresponding to the given node and folder indices.
:param node_idx: The index of the node.
:type node_idx: int
:param folder_idx: The index of the folder on the node.
:type folder_idx: int
:return: The name of the folder. Or None if the node has fewer folders than the given index.
:rtype: Optional[str]
"""
if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
f" is out of range for its action space. Folder on each node: {self.folder_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.folder_names[node_idx][folder_idx]
# :param node_idx: The index of the node.
# :type node_idx: int
# :param folder_idx: The index of the folder on the node.
# :type folder_idx: int
# :return: The name of the folder. Or None if the node has fewer folders than the given index.
# :rtype: Optional[str]
# """
# if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
# msg = (
# f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
# f" is out of range for its action space. Folder on each node: {self.folder_names}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.folder_names[node_idx][folder_idx]
def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
"""Get the file name corresponding to the given node, folder, and file indices.
# def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
# """Get the file name corresponding to the given node, folder, and file indices.
:param node_idx: The index of the node.
:type node_idx: int
:param folder_idx: The index of the folder on the node.
:type folder_idx: int
:param file_idx: The index of the file in the folder.
:type file_idx: int
:return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has
fewer files than the given index.
:rtype: Optional[str]
"""
if (
node_idx >= len(self.file_names)
or folder_idx >= len(self.file_names[node_idx])
or file_idx >= len(self.file_names[node_idx][folder_idx])
):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
f" but this is out of range for its action space. Files on each node: {self.file_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.file_names[node_idx][folder_idx][file_idx]
# :param node_idx: The index of the node.
# :type node_idx: int
# :param folder_idx: The index of the folder on the node.
# :type folder_idx: int
# :param file_idx: The index of the file in the folder.
# :type file_idx: int
# :return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has
# fewer files than the given index.
# :rtype: Optional[str]
# """
# if (
# node_idx >= len(self.file_names)
# or folder_idx >= len(self.file_names[node_idx])
# or file_idx >= len(self.file_names[node_idx][folder_idx])
# ):
# msg = (
# f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
# f" but this is out of range for its action space. Files on each node: {self.file_names}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.file_names[node_idx][folder_idx][file_idx]
def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
"""Get the service name corresponding to the given node and service indices.
# def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
# """Get the service name corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param service_idx: The index of the service on the node.
:type service_idx: int
:return: The name of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
f" is out of range for its action space. Services on each node: {self.service_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.service_names[node_idx][service_idx]
# :param node_idx: The index of the node.
# :type node_idx: int
# :param service_idx: The index of the service on the node.
# :type service_idx: int
# :return: The name of the service. Or None if the node has fewer services than the given index.
# :rtype: Optional[str]
# """
# if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
# msg = (
# f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
# f" is out of range for its action space. Services on each node: {self.service_names}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.service_names[node_idx][service_idx]
def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
"""Get the application name corresponding to the given node and service indices.
# def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
# """Get the application name corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param application_idx: The index of the service on the node.
:type application_idx: int
:return: The name of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
f"this is out of range for its action space. Applications on each node: {self.application_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.application_names[node_idx][application_idx]
# :param node_idx: The index of the node.
# :type node_idx: int
# :param application_idx: The index of the service on the node.
# :type application_idx: int
# :return: The name of the service. Or None if the node has fewer services than the given index.
# :rtype: Optional[str]
# """
# if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
# msg = (
# f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
# f"this is out of range for its action space. Applications on each node: {self.application_names}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.application_names[node_idx][application_idx]
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
"""Get the internet protocol corresponding to the given index.
# def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
# """Get the internet protocol corresponding to the given index.
:param protocol_idx: The index of the protocol to retrieve.
:type protocol_idx: int
:return: The protocol.
:rtype: str
"""
if protocol_idx >= len(self.protocols):
msg = (
f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
f" is out of range for its action space. Protocols: {self.protocols}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.protocols[protocol_idx]
# :param protocol_idx: The index of the protocol to retrieve.
# :type protocol_idx: int
# :return: The protocol.
# :rtype: str
# """
# if protocol_idx >= len(self.protocols):
# msg = (
# f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
# f" is out of range for its action space. Protocols: {self.protocols}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.protocols[protocol_idx]
def get_ip_address_by_idx(self, ip_idx: int) -> str:
"""
Get the IP address corresponding to the given index.
# def get_ip_address_by_idx(self, ip_idx: int) -> str:
# """
# Get the IP address corresponding to the given index.
:param ip_idx: The index of the IP address to retrieve.
:type ip_idx: int
:return: The IP address.
:rtype: str
"""
if ip_idx >= len(self.ip_address_list):
msg = (
f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
f" is out of range for its action space. IP address list: {self.ip_address_list}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.ip_address_list[ip_idx]
# :param ip_idx: The index of the IP address to retrieve.
# :type ip_idx: int
# :return: The IP address.
# :rtype: str
# """
# if ip_idx >= len(self.ip_address_list):
# msg = (
# f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
# f" is out of range for its action space. IP address list: {self.ip_address_list}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.ip_address_list[ip_idx]
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
"""
Get the IP wildcard corresponding to the given index.
# def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
# """
# Get the IP wildcard corresponding to the given index.
:param ip_idx: The index of the IP wildcard to retrieve.
:type ip_idx: int
:return: The wildcard address.
:rtype: str
"""
if wildcard_idx >= len(self.wildcard_list):
msg = (
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.wildcard_list[wildcard_idx]
# :param ip_idx: The index of the IP wildcard to retrieve.
# :type ip_idx: int
# :return: The wildcard address.
# :rtype: str
# """
# if wildcard_idx >= len(self.wildcard_list):
# msg = (
# f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
# f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.wildcard_list[wildcard_idx]
def get_port_by_idx(self, port_idx: int) -> str:
"""
Get the port corresponding to the given index.
# def get_port_by_idx(self, port_idx: int) -> str:
# """
# Get the port corresponding to the given index.
:param port_idx: The index of the port to retrieve.
:type port_idx: int
:return: The port.
:rtype: str
"""
if port_idx >= len(self.ports):
msg = (
f"Error: agent attempted to perform an action on port {port_idx} but this"
f" is out of range for its action space. Port list: {self.ip_address_list}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.ports[port_idx]
# :param port_idx: The index of the port to retrieve.
# :type port_idx: int
# :return: The port.
# :rtype: str
# """
# if port_idx >= len(self.ports):
# msg = (
# f"Error: agent attempted to perform an action on port {port_idx} but this"
# f" is out of range for its action space. Port list: {self.ip_address_list}"
# )
# _LOGGER.error(msg)
# raise RuntimeError(msg)
# return self.ports[port_idx]
def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
"""
Get the NIC number corresponding to the given node and NIC indices.
# def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
# """
# Get the NIC number corresponding to the given node and NIC indices.
:param node_idx: The index of the node.
:type node_idx: int
:param nic_idx: The index of the NIC on the node.
:type nic_idx: int
:return: The NIC number.
:rtype: int
"""
return nic_idx + 1
# :param node_idx: The index of the node.
# :type node_idx: int
# :param nic_idx: The index of the NIC on the node.
# :type nic_idx: int
# :return: The NIC number.
# :rtype: int
# """
# return nic_idx + 1
@classmethod
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":

View File

@@ -0,0 +1,52 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, Dict
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
class NodeAbstractAction(AbstractAction):
"""
Abstract base class for node actions.
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
"""
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.node_name, cls.verb]
class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"):
"""Action which scans a node's OS."""
class ConfigSchema(NodeAbstractAction.ConfigSchema):
verb: str = "scan"
class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"):
"""Action which shuts down a node."""
class ConfigSchema(NodeAbstractAction.ConfigSchema):
verb: str = "shutdown"
class NodeStartupAction(NodeAbstractAction, identifier="node_startup"):
"""Action which starts up a node."""
class ConfigSchema(NodeAbstractAction.ConfigSchema):
verb: str = "startup"
class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
"""Action which resets a node."""
class ConfigSchema(NodeAbstractAction.ConfigSchema):
verb: str = "reset"

View File

@@ -1,7 +1,10 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
class NodeServiceAbstractAction(AbstractAction):
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
@@ -14,29 +17,65 @@ class NodeServiceAbstractAction(AbstractAction):
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.node_name, "service", config.service_name, cls.verb]
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
"""Action which scans a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction, identifier=...):
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
"""Action which stops a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
"""Action which starts a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
"""Action which pauses a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
"""Action which resumes a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
"""Action which restarts a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
"""Action which disables a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
"""Action which enables a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "enable"
class NodeServiceFixAction(NodeServiceAbstractAction):
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
"""Action which fixes a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
verb: str = "fix"