Add network action

This commit is contained in:
Marek Wolan
2023-10-06 10:36:29 +01:00
parent fabd4fd5dd
commit 2a8df074b9
4 changed files with 574 additions and 180 deletions

View File

@@ -131,32 +131,36 @@ game_config:
action_space:
action_list:
- DONOTHING
- NODE_SERVICE_SCAN
- NODE_SERVICE_STOP
# - NODE_SERVICE_START
# - NODE_SERVICE_PAUSE
# - NODE_SERVICE_RESUME
# - NODE_SERVICE_RESTART
# - NODE_SERVICE_DISABLE
# - NODE_SERVICE_ENABLE
# - NODE_FILE_SCAN
# - NODE_FILE_CHECKHASH
# - NODE_FILE_DELETE
# - NODE_FILE_REPAIR
# - NODE_FILE_RESTORE
# - NODE_FOLDER_SCAN
# - NODE_FOLDER_CHECKHASH
# - NODE_FOLDER_REPAIR
# - NODE_FOLDER_RESTORE
# - NODE_OS_SCAN
# - NODE_SHUTDOWN
# - NODE_STARTUP
# - NODE_RESET
# - NETWORK_ACL_ADDRULE
# - NETWORK_ACL_REMOVERULE
# - NETWORK_NIC_ENABLE
- NETWORK_NIC_DISABLE
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:

View File

@@ -1,12 +1,13 @@
import itertools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
import itertools
from primaite.simulator.sim_container import Simulation
from gym import spaces
from primaite.game.session import PrimaiteSession
from primaite.simulator.sim_container import Simulation
class ExecutionDefiniton(ABC):
"""
Converter from actions to simulator requests.
@@ -23,9 +24,8 @@ class ExecutionDefiniton(ABC):
class AbstractAction(ABC):
@abstractmethod
def __init__(self, manager:"ActionManager", **kwargs) -> None:
def __init__(self, manager: "ActionManager", **kwargs) -> None:
"""
Init method for action.
@@ -35,13 +35,12 @@ class AbstractAction(ABC):
per node), we need to pass those options to every action that gets created. To pervent verbosity, these
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
"""
self.name:str = ""
self.name: str = ""
"""Human-readable action identifier used for printing, logging, and reporting."""
self.shape = (0,)
"""Tuple describing number of options for each parameter of this action. Can be passed to
gym.spaces.MultiDiscrete to form a valid space."""
self.manager:ActionManager = manager
self.shape: Dict[str, int] = {}
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
align with the keyword args of the form_request method."""
self.manager: ActionManager = manager
@abstractmethod
def form_request(self) -> List[str]:
@@ -50,14 +49,20 @@ class AbstractAction(ABC):
class DoNothingAction(AbstractAction):
def __init__(self, manager:"ActionManager", **kwargs) -> None:
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
self.name = "DONOTHING"
self.shape = (1,)
self.shape: Dict[str, int] = {
"dummy": 1,
}
# This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1),
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self) -> List[str]:
return ["do_nothing"]
class NodeServiceAbstractAction(AbstractAction):
"""
Base class for service actions.
@@ -65,211 +70,284 @@ class NodeServiceAbstractAction(AbstractAction):
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
this base class.
"""
@abstractmethod
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Tuple[int] = (num_nodes, num_services)
self.verb:str
def form_request(self, node_id:int, service_id:int) -> List[str]:
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str
def form_request(self, node_id: int, service_id: int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
if node_uuid is None or service_uuid is None:
return ["do_nothing"]
return ['network', 'node', node_uuid, 'services', service_uuid, self.verb]
return ["network", "node", node_uuid, "services", service_uuid, self.verb]
class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = "enable"
class NodeFolderAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager=manager)
self.shape = (num_nodes, num_folders)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
def form_request(self, node_id:int, folder_id:int) -> List[str]:
def form_request(self, node_id: int, folder_id: int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
if node_uuid is None or folder_uuid is None:
return ["do_nothing"]
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, self.verb]
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, self.verb]
class NodeFolderScanAction(NodeFolderAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "scan"
self.verb: str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "checkhash"
self.verb: str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction):
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "repair"
self.verb: str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
self.verb:str = "restore"
self.verb: str = "restore"
class NodeFileAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager:"ActionManager", num_nodes:int, num_folders:int, num_files:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape:Tuple[int] = (num_nodes, num_folders, num_files)
self.verb:str
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str
def form_request(self, node_id:int, folder_id:int, file_id:int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
if node_uuid is None or folder_uuid is None or file_uuid is None:
return ["do_nothing"]
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, 'files', file_uuid, self.verb]
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, "files", file_uuid, self.verb]
class NodeFileScanAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
self.verb = "restore"
class NodeAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Tuple[int] = (num_nodes,)
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str
def form_request(self, node_id:int) -> List[str]:
def form_request(self, node_id: int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
return ["network", "node", node_uuid, self.verb]
class NodeOSScanAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'scan'
self.verb = "scan"
class NodeShutdownAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'shutdown'
self.verb = "shutdown"
class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'start'
self.verb = "start"
class NodeResetAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.verb = 'reset'
self.verb = "reset"
class NetworkACLAddRuleAction(AbstractAction):
def __init__(self,
manager: "ActionManager",
target_router_uuid:str,
max_acl_rules:int,
num_ips:int,
num_ports:int,
num_protocols:int,
**kwargs) -> None:
def __init__(
self,
manager: "ActionManager",
target_router_uuid: str,
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
super().__init__(manager=manager)
num_permissions = 2
self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_ips, num_ips, num_ports, num_ports, num_protocols)
self.target_router_uuid:str = target_router_uuid
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_idx": num_ips,
"dest_ip_idx": num_ips,
"source_port_idx": num_ports,
"dest_port_idx": num_ports,
"protocol_idx": num_protocols,
}
self.target_router_uuid: str = target_router_uuid
def form_request(self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx) -> List[str]:
def form_request(
self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx
) -> List[str]:
protocol = self.manager.get_internet_protocol_by_idx(protocol_idx)
src_ip = self.manager.get_ip_address_by_idx(source_ip_idx)
src_port = self.manager.get_port_by_idx(source_port_idx)
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx)
dst_port = self.manager.get_port_by_idx(dest_port_idx)
return [
'network',
'node',
"network",
"node",
self.target_router_uuid,
'acl',
'add_rule',
"acl",
"add_rule",
permission,
protocol,
src_ip,
src_port,
dst_ip,
dst_port,
position
position,
]
class NetworkACLRemoveRuleAction(AbstractAction):
def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
self.target_router_uuid: str = target_router_uuid
def form_request(self, position: int) -> List[str]:
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
class NetworkNICEnableAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
def form_request(self, node_id: int, nic_id: int) -> List[str]:
return [
"network",
"node",
self.manager.get_node_uuid_by_idx(node_idx=node_id),
"nic",
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
"enable",
]
class NetworkNICDisableAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
def form_request(self, node_id: int, nic_id: int) -> List[str]:
return [
"network",
"node",
self.manager.get_node_uuid_by_idx(node_idx=node_id),
"nic",
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
"disable",
]
class ActionManager:
# let the action manager handle the conversion of action spaces into a single discrete integer space.
#
# when action space is created, it will take subspaces and generate an action map by enumerating all possibilities,
# BUT, the action map can be provided in the config, in which case it will use that.
@@ -278,69 +356,83 @@ class ActionManager:
# 0: DONOTHING
# 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0
# 2: ......
__act_class_identifiers:Dict[str,type] = {
__act_class_identifiers: Dict[str, type] = {
"DONOTHING": DoNothingAction,
"NODE_SERVICE_SCAN": NodeServiceScanAction,
"NODE_SERVICE_STOP": NodeServiceStopAction,
# "NODE_SERVICE_START": NodeServiceStartAction,
# "NODE_SERVICE_PAUSE": NodeServicePauseAction,
# "NODE_SERVICE_RESUME": NodeServiceResumeAction,
# "NODE_SERVICE_RESTART": NodeServiceRestartAction,
# "NODE_SERVICE_DISABLE": NodeServiceDisableAction,
# "NODE_SERVICE_ENABLE": NodeServiceEnableAction,
# "NODE_FILE_SCAN": NodeFileScanAction,
# "NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
# "NODE_FILE_DELETE": NodeFileDeleteAction,
# "NODE_FILE_REPAIR": NodeFileRepairAction,
# "NODE_FILE_RESTORE": NodeFileRestoreAction,
"NODE_SERVICE_START": NodeServiceStartAction,
"NODE_SERVICE_PAUSE": NodeServicePauseAction,
"NODE_SERVICE_RESUME": NodeServiceResumeAction,
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
"NODE_FILE_SCAN": NodeFileScanAction,
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
"NODE_FILE_DELETE": NodeFileDeleteAction,
"NODE_FILE_REPAIR": NodeFileRepairAction,
"NODE_FILE_RESTORE": NodeFileRestoreAction,
"NODE_FOLDER_SCAN": NodeFolderScanAction,
# "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
# "NODE_FOLDER_REPAIR": NodeFolderRepairAction,
# "NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
# "NODE_OS_SCAN": NodeOSScanAction,
# "NODE_SHUTDOWN": NodeShutdownAction,
# "NODE_STARTUP": NodeStartupAction,
# "NODE_RESET": NodeResetAction,
# "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
# "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
# "NETWORK_NIC_ENABLE": NetworkNICEnable,
# "NETWORK_NIC_DISABLE": NetworkNICDisable,
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
"NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
"NODE_OS_SCAN": NodeOSScanAction,
"NODE_SHUTDOWN": NodeShutdownAction,
"NODE_STARTUP": NodeStartupAction,
"NODE_RESET": NodeResetAction,
"NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
}
def __init__(
self,
session: PrimaiteSession, # reference to session for looking up stuff
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
max_folders_per_node: int = 2, # allows calculating shape
max_files_per_folder: int = 2, # allows calculating shape
max_services_per_node: int = 2, # allows calculating shape
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
self.session: PrimaiteSession = session
self.sim: Simulation = self.session.simulation
self.node_uuids: List[str] = node_uuids
self.protocols: List[str] = protocols
self.ports: List[str] = ports
def __init__(self,
sim:Simulation,
actions:List[str],
node_uuids:List[str],
max_folders_per_node:int = 2,
max_files_per_folder:int = 2,
max_services_per_node:int = 2,
max_nics_per_node:int=8,
max_acl_rules:int=10,
protocols:List[str]=['TCP','UDP','ICMP'],
ports:List[str]=['HTTP','DNS','ARP','FTP'],
ip_address_list:Optional[List[str]]=None,
act_map:Optional[Dict[int, Dict]]=None) -> None:
self.sim: Simulation = sim
self.node_uuids:List[str] = node_uuids
self.protocols:List[str] = protocols
self.ports:List[str] = ports
self.ip_address_list: List[str]
if ip_address_list is not None:
self.ip_address_list = ip_address_list
else:
self.ip_address_list = []
for node_uuid in self.node_uuids:
node_obj = self.sim.network.nodes[node_uuid]
nics = node_obj.nics
for nic_uuid, nic_obj in nics.items():
self.ip_address_list.append(nic_obj.ip_address)
action_args = {
"num_nodes": len(node_uuids),
"num_folders":max_folders_per_node,
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols),
"num_ports": len(self.protocols),
"num_ips":}
"num_ips": len(self.ip_address_list),
}
self.actions: Dict[str, AbstractAction] = {}
for act_type in actions:
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args)
self.action_map:Dict[int, Tuple[str, Dict]] = {}
self.action_map: Dict[int, Tuple[str, Dict]] = {}
"""
Action mapping that converts an integer to a specific action and parameter choice.
@@ -350,21 +442,30 @@ class ActionManager:
if act_map is None:
self.action_map = self._enumerate_actions()
else:
self.action_map = {i:(a['action'], a['options']) for i,a in act_map.items()}
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
# make sure all numbers between 0 and N are represented as dict keys in action map
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
def _enumerate_actions(self,) -> Dict[int, Tuple[AbstractAction, Dict]]:
...
def _enumerate_actions(
self,
) -> Dict[int, Tuple[AbstractAction, Dict]]:
all_action_possibilities = []
for action in self.actions.values():
param_names = (list(action.shape.keys()),)
num_possibilities = list(action.shape.values())
possibilities = [range(n) for n in num_possibilities]
def get_action(self, action: int) -> Tuple[str,Dict]:
itertools.product(action.shape.values())
all_action_possibilities.append((action, {}))
def get_action(self, action: int) -> Tuple[str, Dict]:
"""Produce action in CAOS format"""
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
"""The caos format is basically a action identifier, followed by parameters stored in a dictionary"""
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier:str, action_options:Dict):
def form_request(self, action_identifier: str, action_options: Dict):
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format"""
act_obj = self.actions[action_identifier]
return act_obj.form_request(**action_options)
@@ -380,37 +481,57 @@ class ActionManager:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
return folder_uuids[folder_idx] if len(folder_uuids)>folder_idx else None
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
if len(folder_uuids)<=folder_idx:
if len(folder_uuids) <= folder_idx:
return None
folder = node.file_system.folders[folder_uuids[folder_idx]]
file_uuids = list(folder.files.keys())
return file_uuids[file_idx] if len(file_uuids)>file_idx else None
return file_uuids[file_idx] if len(file_uuids) > file_idx else None
def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids)>service_idx else None
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
def get_internet_protocol_by_idx(self, protocol_idx:int) -> str:
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
return self.protocols[protocol_idx]
def get_ip_address_by_idx(self, ip_idx: int) -> str:
return self.ip_address_list[ip_idx]
# protocol = self.manager.get_internet_protocol_by_idx(protocol_idx)
# src_ip = self.manager.get_ip_address_by_idx(source_ip_idx)
# src_port = self.manager.get_port_by_idx(source_port_idx)
# dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx)
# dst_port = self.manager.get_port_by_idx(dest_port_idx)
def get_port_by_idx(self, port_idx: int) -> str:
return self.ports[port_idx]
def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str:
node_uuid = self.get_node_uuid_by_idx(node_idx)
node_obj = self.sim.network.nodes[node_uuid]
nics = list(node_obj.nics.keys())
if len(nics) <= nic_idx:
return None
return nics[nic_idx]
@classmethod
def from_config(cls, session: PrimaiteSession, cfg: Dict) -> "ActionManager":
obj = cls(
session=session,
actions=cfg["action_list"],
node_uuids=cfg["options"]["nodes"],
max_folders_per_node=cfg["options"]["max_folders_per_node"],
max_files_per_folder=cfg["options"]["max_files_per_folder"],
max_services_per_node=cfg["options"]["max_services_per_node"],
max_nics_per_node=cfg["options"]["max_nics_per_node"],
max_acl_rules=cfg["options"]["max_acl_rules"],
max_X=cfg["options"]["max_X"],
protocols=session.options.ports,
ports=session.options.protocols,
ip_address_list=None,
act_map=cfg["action_map"],
)
class UC2RedActions(AbstractAction):
...
class UC2GreenActionSpace(ActionManager):
...
return obj

View File

@@ -2,14 +2,16 @@
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
# If you disagree, make a comment in the PR review and we can discuss
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union, TypeAlias
from typing import Any, Dict, List, Optional, TypeAlias, Union
import numpy as np
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
ObsType:TypeAlias = Union[Dict, np.ndarray]
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
@@ -28,31 +30,28 @@ class AbstractAgent(ABC):
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
def get_obs_from_state(self, state:Dict) -> ObsType:
def convert_state_to_obs(self, state: Dict) -> ObsType:
"""
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
def get_reward_from_state(self, state:Dict) -> float:
def calculate_reward_from_state(self, state: Dict) -> float:
return self.reward_function.calculate(state)
@abstractmethod
def get_action(self, obs:ObsType, reward:float=None):
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 1-40,
def get_action(self, obs: ObsType, reward: float = None):
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ('NODE', 'SERVICE', 'SCAN', '<fake-node-sid>', '<fake-service-sid>')
return ("NODE", "SERVICE", "SCAN", "<fake-node-sid>", "<fake-service-sid>")
@abstractmethod
def format_request(self, action) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
return ['network', 'nodes', '<fake-node-uuid>', 'file_system', 'folder', 'root', 'scan']
return ["network", "nodes", "<fake-node-uuid>", "file_system", "folder", "root", "scan"]
class AbstractScriptedAgent(AbstractAgent):
@@ -60,10 +59,11 @@ class AbstractScriptedAgent(AbstractAgent):
...
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs:ObsType, reward:float=None):
def get_action(self, obs: ObsType, reward: float = None):
return self.action_space.space.sample()

View File

@@ -5,23 +5,58 @@
# 4. Create connection with ARCD GATE
# 5. idk
from primaite.simulator.sim_container import Simulation
from primaite.game.agent.interface import AbstractAgent
from ipaddress import IPv4Address
from typing import Dict, List
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations import (
AclObservation,
FileObservation,
FolderObservation,
ICSObservation,
LinkObservation,
NicObservation,
NodeObservation,
NullObservation,
ServiceObservation,
UC2BlueObservation,
UC2RedObservation,
)
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
class PrimaiteSessionOptions(BaseModel):
ports: List[str]
protocols: List[str]
from typing import List
class PrimaiteSession:
def __init__(self):
self.simulation: Simulation = Simulation()
self.agents:List[AbstractAgent] = []
self.step_counter:int = 0
self.episode_counter:int = 0
self.agents: List[AbstractAgent] = []
self.step_counter: int = 0
self.episode_counter: int = 0
self.options: PrimaiteSessionOptions
def step(self):
# currently designed with assumption that all agents act once per step in order
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
@@ -29,10 +64,10 @@ class PrimaiteSession:
sim_state = self.simulation.describe_state()
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.get_obs_from_state(sim_state)
agent_obs = agent.convert_state_to_obs(sim_state)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.get_reward_from_state(sim_state)
agent_reward = agent.calculate_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
@@ -50,3 +85,237 @@ class PrimaiteSession:
self.simulation.apply_timestep(self.step_counter)
self.step_counter += 1
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
sess = cls()
sim = sess.simulation
net = sim.network
ref_map_nodes: Dict[str, Node] = {}
ref_map_services: Dict[str, Service] = {}
ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
for node_cfg in nodes_cfg:
node_ref = node_cfg["ref"]
n_type = node_cfg["type"]
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg["dns_server"],
)
elif n_type == "server":
new_node = Server(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg.get("dns_server"),
)
elif n_type == "switch":
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
elif n_type == "router":
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
# this: 'r_cfg.get('src_port')'
# Port/IPProtocol. TODO Refactor
new_node.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("ip_address"),
dst_ip_address=r_cfg.get("ip_address"),
position=r_num,
)
else:
print("invalid node type")
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
service_ref = service_cfg["ref"]
service_type = service_cfg["type"]
service_types_mapping = {
"DNSClient": DNSClient, # key is equal to the 'name' attr of the service class itself.
"DNSServer": DNSServer,
"DatabaseClient": DatabaseClient,
"DatabaseService": DatabaseService,
# 'database_backup': ,
"DataManipulationBot": DataManipulationBot,
# 'web_browser'
}
if service_type in service_types_mapping:
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
ref_map_services[service_ref] = new_service
else:
print(f"service type not found {service_type}")
# service-dependent options
if service_type == "DatabaseClient":
if "options" in service_cfg:
opt = service_cfg["options"]
if "db_server_ip" in opt:
new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
if service_type == "DNSServer":
if "options" in service_cfg:
opt = service_cfg["options"]
if "domain_mapping" in opt:
for domain, ip in opt["domain_mapping"].items():
new_service.dns_register(domain, ip)
if "nics" in node_cfg:
for nic_num, nic_cfg in node_cfg["nics"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
net.add_node(new_node)
new_node.power_on()
ref_map_nodes[node_ref] = new_node.uuid
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
else:
endpoint_a = node_a.ethernet_port[link_cfg["endpoint_a_port"]]
if isinstance(node_b, Switch):
endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
else:
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
ref_map_links[link_cfg["ref"]] = new_link.uuid
# 3. create agents
game_cfg = cfg["game_config"]
ports_cfg = game_cfg["ports"]
protocols_cfg = game_cfg["protocols"]
agents_cfg = game_cfg["agents"]
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"]
agent_type = agent_cfg["type"]
action_space_cfg = agent_cfg["action_space"]
observation_space_cfg = agent_cfg["observation_space"]
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
if observation_space_cfg is None:
obs_space = NullObservation()
elif observation_space_cfg["type"] == "UC2BlueObservation":
node_obs_list = []
link_obs_list = []
# node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses
node_ip_to_index = {}
for node_idx, node_cfg in enumerate(nodes_cfg):
n_ref = node_cfg["ref"]
n_obj = net.nodes[ref_map_nodes[n_ref]]
for nic_uuid, nic_obj in n_obj.nics.items():
node_ip_to_index[nic_obj.ip_address] = node_idx + 2
for node_obs_cfg in observation_space_cfg["options"]["nodes"]:
node_ref = node_obs_cfg["node_ref"]
folder_obs_list = []
service_obs_list = []
if "services" in node_obs_cfg:
for service_obs_cfg in node_obs_cfg["services"]:
service_obs_list.append(
ServiceObservation(
where=[
"network",
"nodes",
ref_map_nodes[node_ref],
"services",
ref_map_services[service_obs_cfg["service_ref"]],
]
)
)
if "folders" in node_obs_cfg:
for folder_obs_cfg in node_obs_cfg["folders"]:
file_obs_list = []
if "files" in folder_obs_cfg:
for file_obs_cfg in folder_obs_cfg["files"]:
file_obs_list.append(
FileObservation(
where=[
"network",
"nodes",
ref_map_nodes[node_ref],
"folders",
folder_obs_cfg["folder_name"],
"files",
file_obs_cfg["file_name"],
]
)
)
folder_obs_list.append(
FolderObservation(
where=[
"network",
"nodes",
ref_map_nodes[node_ref],
"folders",
folder_obs_cfg["folder_name"],
],
files=file_obs_list,
)
)
nic_obs_list = []
for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg["node_ref"]]].nics.keys():
nic_obs_list.append(
NicObservation(where=["network", "nodes", ref_map_nodes[node_ref], "NICs", nic_uuid])
)
node_obs_list.append(
NodeObservation(
where=["network", "nodes", ref_map_nodes[node_ref]],
services=service_obs_list,
folders=folder_obs_list,
nics=nic_obs_list,
logon_status=False,
)
)
for link_obs_cfg in observation_space_cfg["options"]["links"]:
link_ref = link_obs_cfg["link_ref"]
link_obs_list.append(LinkObservation(where=["network", "links", ref_map_links[link_ref]]))
acl_obs = AclObservation(
node_ip_to_id=node_ip_to_index,
ports=game_cfg["ports"],
protocols=game_cfg["ports"],
where=["network", "nodes", observation_space_cfg["options"]["acl"]["router_node_ref"]],
)
obs_space = UC2BlueObservation(
nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation()
)
elif observation_space_cfg["type"] == "UC2RedObservation":
obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim)
else:
print("observation space config not specified correctly.")
obs_space = NullObservation()
# CREATE ACTION SPACE
action_space = ActionManager.from_config(sess, action_space_cfg)
# CREATE REWARD FUNCTION
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
...
elif agent_type == "GATERLAgent":
...
elif agent_type == "RedDatabaseCorruptingAgent":
...
else:
print("agent type not found")
return sess