#2912 - scaffold of action changes
This commit is contained in:
@@ -9,91 +9,19 @@ AbstractAction. The ActionManager is responsible for:
|
||||
3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This
|
||||
ensures that requests conform to the simulator's request format.
|
||||
"""
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions.manager import ActionManager
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.game.agent.actions.service import NodeServiceAbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
"""Base class for actions."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
"""
|
||||
Init method for action.
|
||||
|
||||
All action init functions should accept **kwargs as a way of ignoring extra arguments.
|
||||
|
||||
Since many parameters are defined for the action space as a whole (such as max files per folder, max services
|
||||
per node), we need to pass those options to every action that gets created. To prevent verbosity, these
|
||||
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
|
||||
"""
|
||||
self.name: str = ""
|
||||
"""Human-readable action identifier used for printing, logging, and reporting."""
|
||||
self.shape: Dict[str, int] = {}
|
||||
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
|
||||
align with the keyword args of the form_request method."""
|
||||
self.manager: ActionManager = manager
|
||||
"""Reference to the ActionManager which created this action. This is used to access the game and simulation
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return []
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction):
|
||||
"""Action which does nothing. This is here to allow agents to be idle if they choose to."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.name = "DONOTHING"
|
||||
self.shape: Dict[str, int] = {
|
||||
"dummy": 1,
|
||||
}
|
||||
# This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1),
|
||||
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self, **kwargs) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for service actions.
|
||||
|
||||
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, service_id: int) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_id)
|
||||
service_name = self.manager.get_service_name_by_idx(node_id, service_id)
|
||||
if node_name is None or service_name is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_name, "service", service_name, self.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
"""Action which scans a service."""
|
||||
@@ -1311,494 +1239,3 @@ class RansomwareLaunchC2ServerAction(AbstractAction):
|
||||
# This action currently doesn't require any further configuration options.
|
||||
return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
act_class_identifiers: Dict[str, type] = {
|
||||
"DONOTHING": DoNothingAction,
|
||||
"NODE_SERVICE_SCAN": NodeServiceScanAction,
|
||||
"NODE_SERVICE_STOP": NodeServiceStopAction,
|
||||
"NODE_SERVICE_START": NodeServiceStartAction,
|
||||
"NODE_SERVICE_PAUSE": NodeServicePauseAction,
|
||||
"NODE_SERVICE_RESUME": NodeServiceResumeAction,
|
||||
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
"NODE_SERVICE_FIX": NodeServiceFixAction,
|
||||
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
|
||||
"NODE_APPLICATION_SCAN": NodeApplicationScanAction,
|
||||
"NODE_APPLICATION_CLOSE": NodeApplicationCloseAction,
|
||||
"NODE_APPLICATION_FIX": NodeApplicationFixAction,
|
||||
"NODE_APPLICATION_INSTALL": NodeApplicationInstallAction,
|
||||
"NODE_APPLICATION_REMOVE": NodeApplicationRemoveAction,
|
||||
"NODE_FILE_SCAN": NodeFileScanAction,
|
||||
"NODE_FILE_CREATE": NodeFileCreateAction,
|
||||
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
"NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
"NODE_FILE_REPAIR": NodeFileRepairAction,
|
||||
"NODE_FILE_RESTORE": NodeFileRestoreAction,
|
||||
"NODE_FILE_CORRUPT": NodeFileCorruptAction,
|
||||
"NODE_FILE_ACCESS": NodeFileAccessAction,
|
||||
"NODE_FOLDER_CREATE": NodeFolderCreateAction,
|
||||
"NODE_FOLDER_SCAN": NodeFolderScanAction,
|
||||
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
|
||||
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
|
||||
"NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
|
||||
"NODE_OS_SCAN": NodeOSScanAction,
|
||||
"NODE_SHUTDOWN": NodeShutdownAction,
|
||||
"NODE_STARTUP": NodeStartupAction,
|
||||
"NODE_RESET": NodeResetAction,
|
||||
"ROUTER_ACL_ADDRULE": RouterACLAddRuleAction,
|
||||
"ROUTER_ACL_REMOVERULE": RouterACLRemoveRuleAction,
|
||||
"FIREWALL_ACL_ADDRULE": FirewallACLAddRuleAction,
|
||||
"FIREWALL_ACL_REMOVERULE": FirewallACLRemoveRuleAction,
|
||||
"HOST_NIC_ENABLE": HostNICEnableAction,
|
||||
"HOST_NIC_DISABLE": HostNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
"NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction,
|
||||
"NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction,
|
||||
"NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction,
|
||||
"CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction,
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction,
|
||||
"CONFIGURE_DOSBOT": ConfigureDoSBotAction,
|
||||
"CONFIGURE_C2_BEACON": ConfigureC2BeaconAction,
|
||||
"C2_SERVER_RANSOMWARE_LAUNCH": RansomwareLaunchC2ServerAction,
|
||||
"C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction,
|
||||
"C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction,
|
||||
"C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction,
|
||||
"NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction,
|
||||
"SSH_TO_REMOTE": NodeSessionsRemoteLoginAction,
|
||||
"SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction,
|
||||
"NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actions: List[Dict], # stores list of actions available to agent
|
||||
nodes: List[Dict], # extra configuration for each node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
max_files_per_folder: int = 2, # allows calculating shape
|
||||
max_services_per_node: int = 2, # allows calculating shape
|
||||
max_applications_per_node: int = 2, # allows calculating shape
|
||||
max_nics_per_node: int = 8, # allows calculating shape
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
|
||||
ip_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
wildcard_list: List[str] = [], # to allow mapping from wildcard index to
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param game: Reference to the game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param actions: List of action specs which should be made available to the agent. The keys of each spec are:
|
||||
'type' and 'options' for passing any options to the action class's init method
|
||||
:type actions: List[dict]
|
||||
:param nodes: Extra configuration for each node.
|
||||
:type nodes: List[Dict]
|
||||
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
|
||||
:type max_folders_per_node: int
|
||||
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
|
||||
:type max_files_per_folder: int
|
||||
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
|
||||
:type max_services_per_node: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
|
||||
:type max_nics_per_node: int
|
||||
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
|
||||
:type max_acl_rules: int
|
||||
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
|
||||
:type protocols: List[str]
|
||||
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
|
||||
:type ports: List[str]
|
||||
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_list: Optional[List[str]]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.node_names: List[str] = [n["node_name"] for n in nodes]
|
||||
"""List of node names in this action space. The list order is the mapping between node index and node name."""
|
||||
self.application_names: List[List[str]] = []
|
||||
"""
|
||||
List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name.
|
||||
The first index corresponds to node id, the second index is the app id on that particular node.
|
||||
For instance, self.application_names[0][2] is the name of the third application on the first node.
|
||||
"""
|
||||
self.service_names: List[List[str]] = []
|
||||
"""
|
||||
List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name.
|
||||
The first index corresponds to node id, the second index is the service id on that particular node.
|
||||
For instance, self.service_names[0][2] is the name of the third service on the first node.
|
||||
"""
|
||||
self.folder_names: List[List[str]] = []
|
||||
"""
|
||||
List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder
|
||||
name. The first index corresponds to node id, the second index is the folder id on that particular node.
|
||||
For instance, self.folder_names[0][2] is the name of the third folder on the first node.
|
||||
"""
|
||||
self.file_names: List[List[List[str]]] = []
|
||||
"""
|
||||
List of files per folder per node. The list order gives the three-index mapping between
|
||||
(node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the
|
||||
folder id on that particular node, and the third index is the file id in that particular folder.
|
||||
For instance, self.file_names[0][2][1] is the name of the second file in the third folder on the first node.
|
||||
"""
|
||||
|
||||
# Populate lists of apps, services, files, folders, etc on nodes.
|
||||
for node in nodes:
|
||||
app_list = [a["application_name"] for a in node.get("applications", [])]
|
||||
while len(app_list) < max_applications_per_node:
|
||||
app_list.append(None)
|
||||
self.application_names.append(app_list)
|
||||
|
||||
svc_list = [s["service_name"] for s in node.get("services", [])]
|
||||
while len(svc_list) < max_services_per_node:
|
||||
svc_list.append(None)
|
||||
self.service_names.append(svc_list)
|
||||
|
||||
folder_list = [f["folder_name"] for f in node.get("folders", [])]
|
||||
while len(folder_list) < max_folders_per_node:
|
||||
folder_list.append(None)
|
||||
self.folder_names.append(folder_list)
|
||||
|
||||
file_sublist = []
|
||||
for folder in node.get("folders", [{"files": []}]):
|
||||
file_list = [f["file_name"] for f in folder.get("files", [])]
|
||||
while len(file_list) < max_files_per_folder:
|
||||
file_list.append(None)
|
||||
file_sublist.append(file_list)
|
||||
while len(file_sublist) < max_folders_per_node:
|
||||
file_sublist.append([None] * max_files_per_folder)
|
||||
self.file_names.append(file_sublist)
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
self.ip_address_list: List[str] = ip_list
|
||||
self.wildcard_list: List[str] = wildcard_list
|
||||
if self.wildcard_list == []:
|
||||
self.wildcard_list = ["NONE"]
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
"num_nodes": len(self.node_names),
|
||||
"num_folders": max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
"num_services": max_services_per_node,
|
||||
"num_applications": max_applications_per_node,
|
||||
"num_nics": max_nics_per_node,
|
||||
"num_acl_rules": max_acl_rules,
|
||||
"num_protocols": len(self.protocols),
|
||||
"num_ports": len(self.protocols),
|
||||
"num_ips": len(self.ip_address_list),
|
||||
"max_acl_rules": max_acl_rules,
|
||||
"max_nics_per_node": max_nics_per_node,
|
||||
}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_spec in actions:
|
||||
# each action is provided into the action space config like this:
|
||||
# - type: ACTION_TYPE
|
||||
# options:
|
||||
# option_1: value1
|
||||
# option_2: value2
|
||||
# where `type` decides which AbstractAction subclass should be used
|
||||
# and `options` is an optional dict of options to pass to the init method of the action class
|
||||
act_type = act_spec.get("type")
|
||||
act_options = act_spec.get("options", {})
|
||||
self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
Action mapping that converts an integer to a specific action and parameter choice.
|
||||
|
||||
For example :
|
||||
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
|
||||
"""
|
||||
if act_map is None:
|
||||
# raise RuntimeError("Action map must be specified in the config file.")
|
||||
pass
|
||||
else:
|
||||
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def _enumerate_actions(
|
||||
self,
|
||||
) -> Dict[int, Tuple[str, Dict]]:
|
||||
"""Generate a list of all the possible actions that could be taken.
|
||||
|
||||
This enumerates all actions all combinations of parameters you could choose for those actions. The output
|
||||
of this function is intended to populate the self.action_map parameter in the situation where the user provides
|
||||
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
|
||||
|
||||
The enumeration relies on the Actions' `shape` attribute.
|
||||
|
||||
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
|
||||
An example output could be:
|
||||
{0: ("DONOTHING", {'dummy': 0}),
|
||||
1: ("NODE_OS_SCAN", {'node_id': 0}),
|
||||
2: ("NODE_OS_SCAN", {'node_id': 1}),
|
||||
3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}),
|
||||
... #etc...
|
||||
}
|
||||
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
|
||||
"""
|
||||
all_action_possibilities = []
|
||||
for act_name, action in self.actions.items():
|
||||
param_names = list(action.shape.keys())
|
||||
num_possibilities = list(action.shape.values())
|
||||
possibilities = [range(n) for n in num_possibilities]
|
||||
|
||||
param_combinations = list(itertools.product(*possibilities))
|
||||
all_action_possibilities.extend(
|
||||
[
|
||||
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
|
||||
for j in range(len(param_combinations))
|
||||
]
|
||||
)
|
||||
|
||||
return {i: p for i, p in enumerate(all_action_possibilities)}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format."""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
"""The CAOS format is basically a action identifier, followed by parameters stored in a dictionary"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_obj = self.actions[action_identifier]
|
||||
return act_obj.form_request(**action_options)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
def get_node_name_by_idx(self, node_idx: int) -> str:
|
||||
"""
|
||||
Get the node name corresponding to the given index.
|
||||
|
||||
:param node_idx: The index of the node to retrieve.
|
||||
:type node_idx: int
|
||||
:return: The node hostname.
|
||||
:rtype: str
|
||||
"""
|
||||
if not node_idx < len(self.node_names):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
|
||||
f"has {len(self.node_names)} nodes."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.node_names[node_idx]
|
||||
|
||||
def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
|
||||
"""
|
||||
Get the folder name corresponding to the given node and folder indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:return: The name of the folder. Or None if the node has fewer folders than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
|
||||
f" is out of range for its action space. Folder on each node: {self.folder_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.folder_names[node_idx][folder_idx]
|
||||
|
||||
def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
|
||||
"""Get the file name corresponding to the given node, folder, and file indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:param file_idx: The index of the file in the folder.
|
||||
:type file_idx: int
|
||||
:return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has
|
||||
fewer files than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if (
|
||||
node_idx >= len(self.file_names)
|
||||
or folder_idx >= len(self.file_names[node_idx])
|
||||
or file_idx >= len(self.file_names[node_idx][folder_idx])
|
||||
):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
|
||||
f" but this is out of range for its action space. Files on each node: {self.file_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.file_names[node_idx][folder_idx][file_idx]
|
||||
|
||||
def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
|
||||
"""Get the service name corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param service_idx: The index of the service on the node.
|
||||
:type service_idx: int
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
|
||||
f" is out of range for its action space. Services on each node: {self.service_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.service_names[node_idx][service_idx]
|
||||
|
||||
def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
|
||||
"""Get the application name corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param application_idx: The index of the service on the node.
|
||||
:type application_idx: int
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
|
||||
f"this is out of range for its action space. Applications on each node: {self.application_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.application_names[node_idx][application_idx]
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
"""Get the internet protocol corresponding to the given index.
|
||||
|
||||
:param protocol_idx: The index of the protocol to retrieve.
|
||||
:type protocol_idx: int
|
||||
:return: The protocol.
|
||||
:rtype: str
|
||||
"""
|
||||
if protocol_idx >= len(self.protocols):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
|
||||
f" is out of range for its action space. Protocols: {self.protocols}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
"""
|
||||
Get the IP address corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP address to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The IP address.
|
||||
:rtype: str
|
||||
"""
|
||||
if ip_idx >= len(self.ip_address_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
|
||||
f" is out of range for its action space. IP address list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
|
||||
"""
|
||||
Get the IP wildcard corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP wildcard to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The wildcard address.
|
||||
:rtype: str
|
||||
"""
|
||||
if wildcard_idx >= len(self.wildcard_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
|
||||
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.wildcard_list[wildcard_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
"""
|
||||
Get the port corresponding to the given index.
|
||||
|
||||
:param port_idx: The index of the port to retrieve.
|
||||
:type port_idx: int
|
||||
:return: The port.
|
||||
:rtype: str
|
||||
"""
|
||||
if port_idx >= len(self.ports):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on port {port_idx} but this"
|
||||
f" is out of range for its action space. Port list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
|
||||
"""
|
||||
Get the NIC number corresponding to the given node and NIC indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param nic_idx: The index of the NIC on the node.
|
||||
:type nic_idx: int
|
||||
:return: The NIC number.
|
||||
:rtype: int
|
||||
"""
|
||||
return nic_idx + 1
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
The action space config supports the following three sections:
|
||||
1. ``action_list``
|
||||
``action_list`` contains a list action components which need to be included in the action space.
|
||||
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
|
||||
which will be passed to the action class's __init__ method during initialisation.
|
||||
2. ``action_map``
|
||||
Since the agent uses a discrete action space which acts as a flattened version of the component-based
|
||||
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
|
||||
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
|
||||
correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be "
|
||||
3. ``options``
|
||||
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param game: The Primaite Game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
if "ip_list" not in cfg["options"]:
|
||||
cfg["options"]["ip_list"] = []
|
||||
|
||||
obj = cls(
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
return obj
|
||||
|
||||
0
src/primaite/game/agent/actions/__init__.py
Normal file
0
src/primaite/game/agent/actions/__init__.py
Normal file
487
src/primaite/game/agent/actions/manager.py
Normal file
487
src/primaite/game/agent/actions/manager.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""yaml example
|
||||
|
||||
agents:
|
||||
- name: agent_1
|
||||
action_space:
|
||||
actions:
|
||||
- do_nothing
|
||||
- node_service_start
|
||||
- node_service_stop
|
||||
action_map:
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.interface.request import RequestFormat
|
||||
from __future__ import annotations
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
import itertools
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Tuple, Type
|
||||
|
||||
class AbstractAction(BaseModel):
|
||||
"""Base class for actions."""
|
||||
# notes:
|
||||
# we actually don't need to hold any state in actions, so there's no need to define any __init__ logic.
|
||||
# all the init methods in the old actions are just used for holding a verb and shape, which are not really used.
|
||||
# the config schema should be used to the actual parameters for formatting the action itself.
|
||||
# (therefore there's no need for creating action instances, just the action class contains logic for converting
|
||||
# CAOS actions to requests for simulator. Similar to the network node adder, that class also doesn't need to be
|
||||
# instantiated.)
|
||||
class ConfigSchema(BaseModel, ABC): # TODO: not sure if this better named something like `Options`
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
|
||||
_registry: ClassVar[Dict[str,Type[AbstractAction]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier:str, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create new action under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def form_request(self, config:ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return []
|
||||
|
||||
class DoNothingAction(AbstractAction):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["do_nothing"] = "do_nothing"
|
||||
|
||||
def form_request(self, options:ConfigSchema) -> RequestFormat:
|
||||
return ["do_nothing"]
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actions: List[Dict], # stores list of actions available to agent
|
||||
# nodes: List[Dict], # extra configuration for each node
|
||||
# max_folders_per_node: int = 2, # allows calculating shape
|
||||
# max_files_per_folder: int = 2, # allows calculating shape
|
||||
# max_services_per_node: int = 2, # allows calculating shape
|
||||
# max_applications_per_node: int = 2, # allows calculating shape
|
||||
# max_nics_per_node: int = 8, # allows calculating shape
|
||||
# max_acl_rules: int = 10, # allows calculating shape
|
||||
# protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
# ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
|
||||
# ip_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
# wildcard_list: List[str] = [], # to allow mapping from wildcard index to
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param game: Reference to the game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param actions: List of action specs which should be made available to the agent. The keys of each spec are:
|
||||
'type' and 'options' for passing any options to the action class's init method
|
||||
:type actions: List[dict]
|
||||
:param nodes: Extra configuration for each node.
|
||||
:type nodes: List[Dict]
|
||||
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
|
||||
:type max_folders_per_node: int
|
||||
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
|
||||
:type max_files_per_folder: int
|
||||
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
|
||||
:type max_services_per_node: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
|
||||
:type max_nics_per_node: int
|
||||
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
|
||||
:type max_acl_rules: int
|
||||
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
|
||||
:type protocols: List[str]
|
||||
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
|
||||
:type ports: List[str]
|
||||
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_list: Optional[List[str]]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.node_names: List[str] = [n["node_name"] for n in nodes]
|
||||
"""List of node names in this action space. The list order is the mapping between node index and node name."""
|
||||
self.application_names: List[List[str]] = []
|
||||
"""
|
||||
List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name.
|
||||
The first index corresponds to node id, the second index is the app id on that particular node.
|
||||
For instance, self.application_names[0][2] is the name of the third application on the first node.
|
||||
"""
|
||||
self.service_names: List[List[str]] = []
|
||||
"""
|
||||
List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name.
|
||||
The first index corresponds to node id, the second index is the service id on that particular node.
|
||||
For instance, self.service_names[0][2] is the name of the third service on the first node.
|
||||
"""
|
||||
self.folder_names: List[List[str]] = []
|
||||
"""
|
||||
List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder
|
||||
name. The first index corresponds to node id, the second index is the folder id on that particular node.
|
||||
For instance, self.folder_names[0][2] is the name of the third folder on the first node.
|
||||
"""
|
||||
self.file_names: List[List[List[str]]] = []
|
||||
"""
|
||||
List of files per folder per node. The list order gives the three-index mapping between
|
||||
(node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the
|
||||
folder id on that particular node, and the third index is the file id in that particular folder.
|
||||
For instance, self.file_names[0][2][1] is the name of the second file in the third folder on the first node.
|
||||
"""
|
||||
|
||||
# Populate lists of apps, services, files, folders, etc on nodes.
|
||||
for node in nodes:
|
||||
app_list = [a["application_name"] for a in node.get("applications", [])]
|
||||
while len(app_list) < max_applications_per_node:
|
||||
app_list.append(None)
|
||||
self.application_names.append(app_list)
|
||||
|
||||
svc_list = [s["service_name"] for s in node.get("services", [])]
|
||||
while len(svc_list) < max_services_per_node:
|
||||
svc_list.append(None)
|
||||
self.service_names.append(svc_list)
|
||||
|
||||
folder_list = [f["folder_name"] for f in node.get("folders", [])]
|
||||
while len(folder_list) < max_folders_per_node:
|
||||
folder_list.append(None)
|
||||
self.folder_names.append(folder_list)
|
||||
|
||||
file_sublist = []
|
||||
for folder in node.get("folders", [{"files": []}]):
|
||||
file_list = [f["file_name"] for f in folder.get("files", [])]
|
||||
while len(file_list) < max_files_per_folder:
|
||||
file_list.append(None)
|
||||
file_sublist.append(file_list)
|
||||
while len(file_sublist) < max_folders_per_node:
|
||||
file_sublist.append([None] * max_files_per_folder)
|
||||
self.file_names.append(file_sublist)
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
self.ip_address_list: List[str] = ip_list
|
||||
self.wildcard_list: List[str] = wildcard_list
|
||||
if self.wildcard_list == []:
|
||||
self.wildcard_list = ["NONE"]
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
"num_nodes": len(self.node_names),
|
||||
"num_folders": max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
"num_services": max_services_per_node,
|
||||
"num_applications": max_applications_per_node,
|
||||
"num_nics": max_nics_per_node,
|
||||
"num_acl_rules": max_acl_rules,
|
||||
"num_protocols": len(self.protocols),
|
||||
"num_ports": len(self.protocols),
|
||||
"num_ips": len(self.ip_address_list),
|
||||
"max_acl_rules": max_acl_rules,
|
||||
"max_nics_per_node": max_nics_per_node,
|
||||
}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_spec in actions:
|
||||
# each action is provided into the action space config like this:
|
||||
# - type: ACTION_TYPE
|
||||
# options:
|
||||
# option_1: value1
|
||||
# option_2: value2
|
||||
# where `type` decides which AbstractAction subclass should be used
|
||||
# and `options` is an optional dict of options to pass to the init method of the action class
|
||||
act_type = act_spec.get("type")
|
||||
act_options = act_spec.get("options", {})
|
||||
self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
Action mapping that converts an integer to a specific action and parameter choice.
|
||||
|
||||
For example :
|
||||
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
|
||||
"""
|
||||
if act_map is None:
|
||||
# raise RuntimeError("Action map must be specified in the config file.")
|
||||
pass
|
||||
else:
|
||||
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def _enumerate_actions(
|
||||
self,
|
||||
) -> Dict[int, Tuple[str, Dict]]:
|
||||
"""Generate a list of all the possible actions that could be taken.
|
||||
|
||||
This enumerates all actions all combinations of parameters you could choose for those actions. The output
|
||||
of this function is intended to populate the self.action_map parameter in the situation where the user provides
|
||||
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
|
||||
|
||||
The enumeration relies on the Actions' `shape` attribute.
|
||||
|
||||
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
|
||||
An example output could be:
|
||||
{0: ("DONOTHING", {'dummy': 0}),
|
||||
1: ("NODE_OS_SCAN", {'node_id': 0}),
|
||||
2: ("NODE_OS_SCAN", {'node_id': 1}),
|
||||
3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}),
|
||||
... #etc...
|
||||
}
|
||||
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
|
||||
"""
|
||||
all_action_possibilities = []
|
||||
for act_name, action in self.actions.items():
|
||||
param_names = list(action.shape.keys())
|
||||
num_possibilities = list(action.shape.values())
|
||||
possibilities = [range(n) for n in num_possibilities]
|
||||
|
||||
param_combinations = list(itertools.product(*possibilities))
|
||||
all_action_possibilities.extend(
|
||||
[
|
||||
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
|
||||
for j in range(len(param_combinations))
|
||||
]
|
||||
)
|
||||
|
||||
return {i: p for i, p in enumerate(all_action_possibilities)}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format."""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
"""The CAOS format is basically a action identifier, followed by parameters stored in a dictionary"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_obj = self.actions[action_identifier]
|
||||
return act_obj.form_request(**action_options)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
def get_node_name_by_idx(self, node_idx: int) -> str:
|
||||
"""
|
||||
Get the node name corresponding to the given index.
|
||||
|
||||
:param node_idx: The index of the node to retrieve.
|
||||
:type node_idx: int
|
||||
:return: The node hostname.
|
||||
:rtype: str
|
||||
"""
|
||||
if not node_idx < len(self.node_names):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
|
||||
f"has {len(self.node_names)} nodes."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.node_names[node_idx]
|
||||
|
||||
def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
|
||||
"""
|
||||
Get the folder name corresponding to the given node and folder indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:return: The name of the folder. Or None if the node has fewer folders than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
|
||||
f" is out of range for its action space. Folder on each node: {self.folder_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.folder_names[node_idx][folder_idx]
|
||||
|
||||
def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
|
||||
"""Get the file name corresponding to the given node, folder, and file indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:param file_idx: The index of the file in the folder.
|
||||
:type file_idx: int
|
||||
:return: The name of the file. Or None if the node has fewer folders than the given index, or the folder has
|
||||
fewer files than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if (
|
||||
node_idx >= len(self.file_names)
|
||||
or folder_idx >= len(self.file_names[node_idx])
|
||||
or file_idx >= len(self.file_names[node_idx][folder_idx])
|
||||
):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
|
||||
f" but this is out of range for its action space. Files on each node: {self.file_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.file_names[node_idx][folder_idx][file_idx]
|
||||
|
||||
def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
|
||||
"""Get the service name corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param service_idx: The index of the service on the node.
|
||||
:type service_idx: int
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
|
||||
f" is out of range for its action space. Services on each node: {self.service_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.service_names[node_idx][service_idx]
|
||||
|
||||
def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
|
||||
"""Get the application name corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param application_idx: The index of the service on the node.
|
||||
:type application_idx: int
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
|
||||
f"this is out of range for its action space. Applications on each node: {self.application_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.application_names[node_idx][application_idx]
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
"""Get the internet protocol corresponding to the given index.
|
||||
|
||||
:param protocol_idx: The index of the protocol to retrieve.
|
||||
:type protocol_idx: int
|
||||
:return: The protocol.
|
||||
:rtype: str
|
||||
"""
|
||||
if protocol_idx >= len(self.protocols):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
|
||||
f" is out of range for its action space. Protocols: {self.protocols}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
"""
|
||||
Get the IP address corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP address to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The IP address.
|
||||
:rtype: str
|
||||
"""
|
||||
if ip_idx >= len(self.ip_address_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
|
||||
f" is out of range for its action space. IP address list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
|
||||
"""
|
||||
Get the IP wildcard corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP wildcard to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The wildcard address.
|
||||
:rtype: str
|
||||
"""
|
||||
if wildcard_idx >= len(self.wildcard_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
|
||||
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.wildcard_list[wildcard_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
"""
|
||||
Get the port corresponding to the given index.
|
||||
|
||||
:param port_idx: The index of the port to retrieve.
|
||||
:type port_idx: int
|
||||
:return: The port.
|
||||
:rtype: str
|
||||
"""
|
||||
if port_idx >= len(self.ports):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on port {port_idx} but this"
|
||||
f" is out of range for its action space. Port list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
|
||||
"""
|
||||
Get the NIC number corresponding to the given node and NIC indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param nic_idx: The index of the NIC on the node.
|
||||
:type nic_idx: int
|
||||
:return: The NIC number.
|
||||
:rtype: int
|
||||
"""
|
||||
return nic_idx + 1
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
The action space config supports the following three sections:
|
||||
1. ``action_list``
|
||||
``action_list`` contains a list action components which need to be included in the action space.
|
||||
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
|
||||
which will be passed to the action class's __init__ method during initialisation.
|
||||
2. ``action_map``
|
||||
Since the agent uses a discrete action space which acts as a flattened version of the component-based
|
||||
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
|
||||
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
|
||||
correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be "
|
||||
3. ``options``
|
||||
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param game: The Primaite Game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
if "ip_list" not in cfg["options"]:
|
||||
cfg["options"]["ip_list"] = []
|
||||
|
||||
obj = cls(
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
return obj
|
||||
42
src/primaite/game/agent/actions/service.py
Normal file
42
src/primaite/game/agent/actions/service.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import ClassVar
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
node_name: str
|
||||
service_name: str
|
||||
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config:ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.node_name, "service", config.service_name, cls.verb]
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
|
||||
verb: str = "scan"
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction, identifier=...):
|
||||
verb: str = "stop"
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
verb: str = "start"
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
verb: str = "pause"
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
verb: str = "resume"
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
verb: str = "restart"
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
verb: str = "disable"
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
verb: str = "enable"
|
||||
|
||||
class NodeServiceFixAction(NodeServiceAbstractAction):
|
||||
verb: str = "fix"
|
||||
Reference in New Issue
Block a user