Add network action
This commit is contained in:
@@ -131,32 +131,36 @@ game_config:
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- DONOTHING
|
||||
- NODE_SERVICE_SCAN
|
||||
- NODE_SERVICE_STOP
|
||||
# - NODE_SERVICE_START
|
||||
# - NODE_SERVICE_PAUSE
|
||||
# - NODE_SERVICE_RESUME
|
||||
# - NODE_SERVICE_RESTART
|
||||
# - NODE_SERVICE_DISABLE
|
||||
# - NODE_SERVICE_ENABLE
|
||||
# - NODE_FILE_SCAN
|
||||
# - NODE_FILE_CHECKHASH
|
||||
# - NODE_FILE_DELETE
|
||||
# - NODE_FILE_REPAIR
|
||||
# - NODE_FILE_RESTORE
|
||||
# - NODE_FOLDER_SCAN
|
||||
# - NODE_FOLDER_CHECKHASH
|
||||
# - NODE_FOLDER_REPAIR
|
||||
# - NODE_FOLDER_RESTORE
|
||||
# - NODE_OS_SCAN
|
||||
# - NODE_SHUTDOWN
|
||||
# - NODE_STARTUP
|
||||
# - NODE_RESET
|
||||
# - NETWORK_ACL_ADDRULE
|
||||
# - NETWORK_ACL_REMOVERULE
|
||||
# - NETWORK_NIC_ENABLE
|
||||
- NETWORK_NIC_DISABLE
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import itertools
|
||||
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
class ExecutionDefiniton(ABC):
|
||||
"""
|
||||
Converter from actions to simulator requests.
|
||||
@@ -23,9 +24,8 @@ class ExecutionDefiniton(ABC):
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
"""
|
||||
Init method for action.
|
||||
|
||||
@@ -35,13 +35,12 @@ class AbstractAction(ABC):
|
||||
per node), we need to pass those options to every action that gets created. To pervent verbosity, these
|
||||
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
|
||||
"""
|
||||
self.name:str = ""
|
||||
self.name: str = ""
|
||||
"""Human-readable action identifier used for printing, logging, and reporting."""
|
||||
self.shape = (0,)
|
||||
"""Tuple describing number of options for each parameter of this action. Can be passed to
|
||||
gym.spaces.MultiDiscrete to form a valid space."""
|
||||
self.manager:ActionManager = manager
|
||||
|
||||
self.shape: Dict[str, int] = {}
|
||||
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
|
||||
align with the keyword args of the form_request method."""
|
||||
self.manager: ActionManager = manager
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> List[str]:
|
||||
@@ -50,14 +49,20 @@ class AbstractAction(ABC):
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction):
|
||||
def __init__(self, manager:"ActionManager", **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.name = "DONOTHING"
|
||||
self.shape = (1,)
|
||||
self.shape: Dict[str, int] = {
|
||||
"dummy": 1,
|
||||
}
|
||||
# This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1),
|
||||
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self) -> List[str]:
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for service actions.
|
||||
@@ -65,211 +70,284 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Tuple[int] = (num_nodes, num_services)
|
||||
self.verb:str
|
||||
|
||||
def form_request(self, node_id:int, service_id:int) -> List[str]:
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, service_id: int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
|
||||
if node_uuid is None or service_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ['network', 'node', node_uuid, 'services', service_uuid, self.verb]
|
||||
return ["network", "node", node_uuid, "services", service_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape = (num_nodes, num_folders)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id:int, folder_id:int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
if node_uuid is None or folder_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, self.verb]
|
||||
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "scan"
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "checkhash"
|
||||
self.verb: str = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager:"ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "repair"
|
||||
self.verb: str = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
self.verb:str = "restore"
|
||||
self.verb: str = "restore"
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager:"ActionManager", num_nodes:int, num_folders:int, num_files:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape:Tuple[int] = (num_nodes, num_folders, num_files)
|
||||
self.verb:str
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id:int, folder_id:int, file_id:int) -> List[str]:
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
if node_uuid is None or folder_uuid is None or file_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ['network', 'node', node_uuid, 'file_system', 'folder', folder_uuid, 'files', file_uuid, self.verb]
|
||||
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, "files", file_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "checkhash"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "delete"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "repair"
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
self.verb = "restore"
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Tuple[int] = (num_nodes,)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id:int) -> List[str]:
|
||||
def form_request(self, node_id: int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
return ["network", "node", node_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'scan'
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'shutdown'
|
||||
self.verb = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'start'
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.verb = 'reset'
|
||||
self.verb = "reset"
|
||||
|
||||
|
||||
class NetworkACLAddRuleAction(AbstractAction):
|
||||
def __init__(self,
|
||||
manager: "ActionManager",
|
||||
target_router_uuid:str,
|
||||
max_acl_rules:int,
|
||||
num_ips:int,
|
||||
num_ports:int,
|
||||
num_protocols:int,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
manager: "ActionManager",
|
||||
target_router_uuid: str,
|
||||
max_acl_rules: int,
|
||||
num_ips: int,
|
||||
num_ports: int,
|
||||
num_protocols: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 2
|
||||
self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_ips, num_ips, num_ports, num_ports, num_protocols)
|
||||
self.target_router_uuid:str = target_router_uuid
|
||||
self.shape: Dict[str, int] = {
|
||||
"position": max_acl_rules,
|
||||
"permission": num_permissions,
|
||||
"source_ip_idx": num_ips,
|
||||
"dest_ip_idx": num_ips,
|
||||
"source_port_idx": num_ports,
|
||||
"dest_port_idx": num_ports,
|
||||
"protocol_idx": num_protocols,
|
||||
}
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx) -> List[str]:
|
||||
def form_request(
|
||||
self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx
|
||||
) -> List[str]:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_idx)
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_idx)
|
||||
src_port = self.manager.get_port_by_idx(source_port_idx)
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx)
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_idx)
|
||||
return [
|
||||
'network',
|
||||
'node',
|
||||
"network",
|
||||
"node",
|
||||
self.target_router_uuid,
|
||||
'acl',
|
||||
'add_rule',
|
||||
"acl",
|
||||
"add_rule",
|
||||
permission,
|
||||
protocol,
|
||||
src_ip,
|
||||
src_port,
|
||||
dst_ip,
|
||||
dst_port,
|
||||
position
|
||||
position,
|
||||
]
|
||||
|
||||
|
||||
class NetworkACLRemoveRuleAction(AbstractAction):
|
||||
def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(self, position: int) -> List[str]:
|
||||
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
|
||||
|
||||
|
||||
class NetworkNICEnableAction(AbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.manager.get_node_uuid_by_idx(node_idx=node_id),
|
||||
"nic",
|
||||
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
|
||||
"enable",
|
||||
]
|
||||
|
||||
|
||||
class NetworkNICDisableAction(AbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.manager.get_node_uuid_by_idx(node_idx=node_id),
|
||||
"nic",
|
||||
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
|
||||
"disable",
|
||||
]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
# let the action manager handle the conversion of action spaces into a single discrete integer space.
|
||||
#
|
||||
|
||||
|
||||
# when action space is created, it will take subspaces and generate an action map by enumerating all possibilities,
|
||||
# BUT, the action map can be provided in the config, in which case it will use that.
|
||||
|
||||
@@ -278,69 +356,83 @@ class ActionManager:
|
||||
# 0: DONOTHING
|
||||
# 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0
|
||||
# 2: ......
|
||||
__act_class_identifiers:Dict[str,type] = {
|
||||
__act_class_identifiers: Dict[str, type] = {
|
||||
"DONOTHING": DoNothingAction,
|
||||
"NODE_SERVICE_SCAN": NodeServiceScanAction,
|
||||
"NODE_SERVICE_STOP": NodeServiceStopAction,
|
||||
# "NODE_SERVICE_START": NodeServiceStartAction,
|
||||
# "NODE_SERVICE_PAUSE": NodeServicePauseAction,
|
||||
# "NODE_SERVICE_RESUME": NodeServiceResumeAction,
|
||||
# "NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
# "NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
# "NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
# "NODE_FILE_SCAN": NodeFileScanAction,
|
||||
# "NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
# "NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
# "NODE_FILE_REPAIR": NodeFileRepairAction,
|
||||
# "NODE_FILE_RESTORE": NodeFileRestoreAction,
|
||||
"NODE_SERVICE_START": NodeServiceStartAction,
|
||||
"NODE_SERVICE_PAUSE": NodeServicePauseAction,
|
||||
"NODE_SERVICE_RESUME": NodeServiceResumeAction,
|
||||
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
"NODE_FILE_SCAN": NodeFileScanAction,
|
||||
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
"NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
"NODE_FILE_REPAIR": NodeFileRepairAction,
|
||||
"NODE_FILE_RESTORE": NodeFileRestoreAction,
|
||||
"NODE_FOLDER_SCAN": NodeFolderScanAction,
|
||||
# "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
|
||||
# "NODE_FOLDER_REPAIR": NodeFolderRepairAction,
|
||||
# "NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
|
||||
# "NODE_OS_SCAN": NodeOSScanAction,
|
||||
# "NODE_SHUTDOWN": NodeShutdownAction,
|
||||
# "NODE_STARTUP": NodeStartupAction,
|
||||
# "NODE_RESET": NodeResetAction,
|
||||
# "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
|
||||
# "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
# "NETWORK_NIC_ENABLE": NetworkNICEnable,
|
||||
# "NETWORK_NIC_DISABLE": NetworkNICDisable,
|
||||
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
|
||||
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
|
||||
"NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
|
||||
"NODE_OS_SCAN": NodeOSScanAction,
|
||||
"NODE_SHUTDOWN": NodeShutdownAction,
|
||||
"NODE_STARTUP": NodeStartupAction,
|
||||
"NODE_RESET": NodeResetAction,
|
||||
"NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
|
||||
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: PrimaiteSession, # reference to session for looking up stuff
|
||||
actions: List[str], # stores list of actions available to agent
|
||||
node_uuids: List[str], # allows mapping index to node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
max_files_per_folder: int = 2, # allows calculating shape
|
||||
max_services_per_node: int = 2, # allows calculating shape
|
||||
max_nics_per_node: int = 8, # allows calculating shape
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
|
||||
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
self.session: PrimaiteSession = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
def __init__(self,
|
||||
sim:Simulation,
|
||||
actions:List[str],
|
||||
node_uuids:List[str],
|
||||
max_folders_per_node:int = 2,
|
||||
max_files_per_folder:int = 2,
|
||||
max_services_per_node:int = 2,
|
||||
max_nics_per_node:int=8,
|
||||
max_acl_rules:int=10,
|
||||
protocols:List[str]=['TCP','UDP','ICMP'],
|
||||
ports:List[str]=['HTTP','DNS','ARP','FTP'],
|
||||
ip_address_list:Optional[List[str]]=None,
|
||||
act_map:Optional[Dict[int, Dict]]=None) -> None:
|
||||
self.sim: Simulation = sim
|
||||
self.node_uuids:List[str] = node_uuids
|
||||
self.protocols:List[str] = protocols
|
||||
self.ports:List[str] = ports
|
||||
self.ip_address_list: List[str]
|
||||
if ip_address_list is not None:
|
||||
self.ip_address_list = ip_address_list
|
||||
else:
|
||||
self.ip_address_list = []
|
||||
for node_uuid in self.node_uuids:
|
||||
node_obj = self.sim.network.nodes[node_uuid]
|
||||
nics = node_obj.nics
|
||||
for nic_uuid, nic_obj in nics.items():
|
||||
self.ip_address_list.append(nic_obj.ip_address)
|
||||
|
||||
action_args = {
|
||||
"num_nodes": len(node_uuids),
|
||||
"num_folders":max_folders_per_node,
|
||||
"num_folders": max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
"num_services": max_services_per_node,
|
||||
"num_nics": max_nics_per_node,
|
||||
"num_acl_rules": max_acl_rules,
|
||||
"num_protocols": len(self.protocols),
|
||||
"num_ports": len(self.protocols),
|
||||
"num_ips":}
|
||||
"num_ips": len(self.ip_address_list),
|
||||
}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_type in actions:
|
||||
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args)
|
||||
|
||||
self.action_map:Dict[int, Tuple[str, Dict]] = {}
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
Action mapping that converts an integer to a specific action and parameter choice.
|
||||
|
||||
@@ -350,21 +442,30 @@ class ActionManager:
|
||||
if act_map is None:
|
||||
self.action_map = self._enumerate_actions()
|
||||
else:
|
||||
self.action_map = {i:(a['action'], a['options']) for i,a in act_map.items()}
|
||||
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def _enumerate_actions(self,) -> Dict[int, Tuple[AbstractAction, Dict]]:
|
||||
...
|
||||
def _enumerate_actions(
|
||||
self,
|
||||
) -> Dict[int, Tuple[AbstractAction, Dict]]:
|
||||
all_action_possibilities = []
|
||||
for action in self.actions.values():
|
||||
param_names = (list(action.shape.keys()),)
|
||||
num_possibilities = list(action.shape.values())
|
||||
possibilities = [range(n) for n in num_possibilities]
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str,Dict]:
|
||||
itertools.product(action.shape.values())
|
||||
all_action_possibilities.append((action, {}))
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format"""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
"""The caos format is basically a action identifier, followed by parameters stored in a dictionary"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier:str, action_options:Dict):
|
||||
def form_request(self, action_identifier: str, action_options: Dict):
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format"""
|
||||
act_obj = self.actions[action_identifier]
|
||||
return act_obj.form_request(**action_options)
|
||||
@@ -380,37 +481,57 @@ class ActionManager:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
return folder_uuids[folder_idx] if len(folder_uuids)>folder_idx else None
|
||||
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
|
||||
|
||||
def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
if len(folder_uuids)<=folder_idx:
|
||||
if len(folder_uuids) <= folder_idx:
|
||||
return None
|
||||
folder = node.file_system.folders[folder_uuids[folder_idx]]
|
||||
file_uuids = list(folder.files.keys())
|
||||
return file_uuids[file_idx] if len(file_uuids)>file_idx else None
|
||||
return file_uuids[file_idx] if len(file_uuids) > file_idx else None
|
||||
|
||||
def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
service_uuids = list(node.services.keys())
|
||||
return service_uuids[service_idx] if len(service_uuids)>service_idx else None
|
||||
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx:int) -> str:
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
# protocol = self.manager.get_internet_protocol_by_idx(protocol_idx)
|
||||
# src_ip = self.manager.get_ip_address_by_idx(source_ip_idx)
|
||||
# src_port = self.manager.get_port_by_idx(source_port_idx)
|
||||
# dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx)
|
||||
# dst_port = self.manager.get_port_by_idx(dest_port_idx)
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str:
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node_obj = self.sim.network.nodes[node_uuid]
|
||||
nics = list(node_obj.nics.keys())
|
||||
if len(nics) <= nic_idx:
|
||||
return None
|
||||
return nics[nic_idx]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: PrimaiteSession, cfg: Dict) -> "ActionManager":
|
||||
obj = cls(
|
||||
session=session,
|
||||
actions=cfg["action_list"],
|
||||
node_uuids=cfg["options"]["nodes"],
|
||||
max_folders_per_node=cfg["options"]["max_folders_per_node"],
|
||||
max_files_per_folder=cfg["options"]["max_files_per_folder"],
|
||||
max_services_per_node=cfg["options"]["max_services_per_node"],
|
||||
max_nics_per_node=cfg["options"]["max_nics_per_node"],
|
||||
max_acl_rules=cfg["options"]["max_acl_rules"],
|
||||
max_X=cfg["options"]["max_X"],
|
||||
protocols=session.options.ports,
|
||||
ports=session.options.protocols,
|
||||
ip_address_list=None,
|
||||
act_map=cfg["action_map"],
|
||||
)
|
||||
|
||||
class UC2RedActions(AbstractAction):
|
||||
...
|
||||
|
||||
class UC2GreenActionSpace(ActionManager):
|
||||
...
|
||||
return obj
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
|
||||
# If you disagree, make a comment in the PR review and we can discuss
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union, TypeAlias
|
||||
from typing import Any, Dict, List, Optional, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
ObsType:TypeAlias = Union[Dict, np.ndarray]
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
@@ -28,31 +30,28 @@ class AbstractAgent(ABC):
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = None
|
||||
|
||||
def get_obs_from_state(self, state:Dict) -> ObsType:
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
|
||||
def get_reward_from_state(self, state:Dict) -> float:
|
||||
def calculate_reward_from_state(self, state: Dict) -> float:
|
||||
return self.reward_function.calculate(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs:ObsType, reward:float=None):
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 1-40,
|
||||
def get_action(self, obs: ObsType, reward: float = None):
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ('NODE', 'SERVICE', 'SCAN', '<fake-node-sid>', '<fake-service-sid>')
|
||||
return ("NODE", "SERVICE", "SCAN", "<fake-node-sid>", "<fake-service-sid>")
|
||||
|
||||
@abstractmethod
|
||||
def format_request(self, action) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
return ['network', 'nodes', '<fake-node-uuid>', 'file_system', 'folder', 'root', 'scan']
|
||||
|
||||
|
||||
|
||||
return ["network", "nodes", "<fake-node-uuid>", "file_system", "folder", "root", "scan"]
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
@@ -60,10 +59,11 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
|
||||
...
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs:ObsType, reward:float=None):
|
||||
def get_action(self, obs: ObsType, reward: float = None):
|
||||
return self.action_space.space.sample()
|
||||
|
||||
|
||||
|
||||
@@ -5,23 +5,58 @@
|
||||
# 4. Create connection with ARCD GATE
|
||||
# 5. idk
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.observations import (
|
||||
AclObservation,
|
||||
FileObservation,
|
||||
FolderObservation,
|
||||
ICSObservation,
|
||||
LinkObservation,
|
||||
NicObservation,
|
||||
NodeObservation,
|
||||
NullObservation,
|
||||
ServiceObservation,
|
||||
UC2BlueObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
from typing import List
|
||||
|
||||
class PrimaiteSession:
|
||||
def __init__(self):
|
||||
self.simulation: Simulation = Simulation()
|
||||
self.agents:List[AbstractAgent] = []
|
||||
self.step_counter:int = 0
|
||||
self.episode_counter:int = 0
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
self.step_counter: int = 0
|
||||
self.episode_counter: int = 0
|
||||
self.options: PrimaiteSessionOptions
|
||||
|
||||
def step(self):
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
|
||||
for agent in self.agents:
|
||||
# 3. primaite session asks simulation to provide initial state
|
||||
# 4. primate session gives state to all agents
|
||||
@@ -29,10 +64,10 @@ class PrimaiteSession:
|
||||
sim_state = self.simulation.describe_state()
|
||||
|
||||
# 6. each agent takes most recent state and converts it to CAOS observation
|
||||
agent_obs = agent.get_obs_from_state(sim_state)
|
||||
agent_obs = agent.convert_state_to_obs(sim_state)
|
||||
|
||||
# 7. meanwhile each agent also takes state and calculates reward
|
||||
agent_reward = agent.get_reward_from_state(sim_state)
|
||||
agent_reward = agent.calculate_reward_from_state(sim_state)
|
||||
|
||||
# 8. each agent takes observation and applies decision rule to observation to create CAOS
|
||||
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
|
||||
@@ -50,3 +85,237 @@ class PrimaiteSession:
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
self.step_counter += 1
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
sess = cls()
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
ref_map_nodes: Dict[str, Node] = {}
|
||||
ref_map_services: Dict[str, Service] = {}
|
||||
ref_map_links: Dict[str, Link] = {}
|
||||
|
||||
nodes_cfg = cfg["simulation"]["network"]["nodes"]
|
||||
links_cfg = cfg["simulation"]["network"]["links"]
|
||||
for node_cfg in nodes_cfg:
|
||||
node_ref = node_cfg["ref"]
|
||||
n_type = node_cfg["type"]
|
||||
if n_type == "computer":
|
||||
new_node = Computer(
|
||||
hostname=node_cfg["hostname"],
|
||||
ip_address=node_cfg["ip_address"],
|
||||
subnet_mask=node_cfg["subnet_mask"],
|
||||
default_gateway=node_cfg["default_gateway"],
|
||||
dns_server=node_cfg["dns_server"],
|
||||
)
|
||||
elif n_type == "server":
|
||||
new_node = Server(
|
||||
hostname=node_cfg["hostname"],
|
||||
ip_address=node_cfg["ip_address"],
|
||||
subnet_mask=node_cfg["subnet_mask"],
|
||||
default_gateway=node_cfg["default_gateway"],
|
||||
dns_server=node_cfg.get("dns_server"),
|
||||
)
|
||||
elif n_type == "switch":
|
||||
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
|
||||
elif n_type == "router":
|
||||
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
|
||||
if "ports" in node_cfg:
|
||||
for port_num, port_cfg in node_cfg["ports"].items():
|
||||
new_node.configure_port(
|
||||
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
|
||||
)
|
||||
if "acl" in node_cfg:
|
||||
for r_num, r_cfg in node_cfg["acl"].items():
|
||||
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
|
||||
# this: 'r_cfg.get('src_port')'
|
||||
# Port/IPProtocol. TODO Refactor
|
||||
new_node.acl.add_rule(
|
||||
action=ACLAction[r_cfg["action"]],
|
||||
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("ip_address"),
|
||||
dst_ip_address=r_cfg.get("ip_address"),
|
||||
position=r_num,
|
||||
)
|
||||
else:
|
||||
print("invalid node type")
|
||||
if "services" in node_cfg:
|
||||
for service_cfg in node_cfg["services"]:
|
||||
service_ref = service_cfg["ref"]
|
||||
service_type = service_cfg["type"]
|
||||
service_types_mapping = {
|
||||
"DNSClient": DNSClient, # key is equal to the 'name' attr of the service class itself.
|
||||
"DNSServer": DNSServer,
|
||||
"DatabaseClient": DatabaseClient,
|
||||
"DatabaseService": DatabaseService,
|
||||
# 'database_backup': ,
|
||||
"DataManipulationBot": DataManipulationBot,
|
||||
# 'web_browser'
|
||||
}
|
||||
if service_type in service_types_mapping:
|
||||
new_node.software_manager.install(service_types_mapping[service_type])
|
||||
new_service = new_node.software_manager.software[service_type]
|
||||
ref_map_services[service_ref] = new_service
|
||||
else:
|
||||
print(f"service type not found {service_type}")
|
||||
# service-dependent options
|
||||
if service_type == "DatabaseClient":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
if "db_server_ip" in opt:
|
||||
new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
|
||||
if service_type == "DNSServer":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
if "domain_mapping" in opt:
|
||||
for domain, ip in opt["domain_mapping"].items():
|
||||
new_service.dns_register(domain, ip)
|
||||
if "nics" in node_cfg:
|
||||
for nic_num, nic_cfg in node_cfg["nics"].items():
|
||||
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
|
||||
|
||||
net.add_node(new_node)
|
||||
new_node.power_on()
|
||||
ref_map_nodes[node_ref] = new_node.uuid
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
node_a = net.nodes[ref_map_nodes[link_cfg["endpoint_a_ref"]]]
|
||||
node_b = net.nodes[ref_map_nodes[link_cfg["endpoint_b_ref"]]]
|
||||
if isinstance(node_a, Switch):
|
||||
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
|
||||
else:
|
||||
endpoint_a = node_a.ethernet_port[link_cfg["endpoint_a_port"]]
|
||||
if isinstance(node_b, Switch):
|
||||
endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
|
||||
else:
|
||||
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
|
||||
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
|
||||
ref_map_links[link_cfg["ref"]] = new_link.uuid
|
||||
|
||||
# 3. create agents
|
||||
game_cfg = cfg["game_config"]
|
||||
ports_cfg = game_cfg["ports"]
|
||||
protocols_cfg = game_cfg["protocols"]
|
||||
agents_cfg = game_cfg["agents"]
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_ref = agent_cfg["ref"]
|
||||
agent_type = agent_cfg["type"]
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
if observation_space_cfg is None:
|
||||
obs_space = NullObservation()
|
||||
elif observation_space_cfg["type"] == "UC2BlueObservation":
|
||||
node_obs_list = []
|
||||
link_obs_list = []
|
||||
|
||||
# node ip to index maps ip addresses to node id, as there are potentially multiple nics on a node, there are multiple ip addresses
|
||||
node_ip_to_index = {}
|
||||
for node_idx, node_cfg in enumerate(nodes_cfg):
|
||||
n_ref = node_cfg["ref"]
|
||||
n_obj = net.nodes[ref_map_nodes[n_ref]]
|
||||
for nic_uuid, nic_obj in n_obj.nics.items():
|
||||
node_ip_to_index[nic_obj.ip_address] = node_idx + 2
|
||||
|
||||
for node_obs_cfg in observation_space_cfg["options"]["nodes"]:
|
||||
node_ref = node_obs_cfg["node_ref"]
|
||||
folder_obs_list = []
|
||||
service_obs_list = []
|
||||
if "services" in node_obs_cfg:
|
||||
for service_obs_cfg in node_obs_cfg["services"]:
|
||||
service_obs_list.append(
|
||||
ServiceObservation(
|
||||
where=[
|
||||
"network",
|
||||
"nodes",
|
||||
ref_map_nodes[node_ref],
|
||||
"services",
|
||||
ref_map_services[service_obs_cfg["service_ref"]],
|
||||
]
|
||||
)
|
||||
)
|
||||
if "folders" in node_obs_cfg:
|
||||
for folder_obs_cfg in node_obs_cfg["folders"]:
|
||||
file_obs_list = []
|
||||
if "files" in folder_obs_cfg:
|
||||
for file_obs_cfg in folder_obs_cfg["files"]:
|
||||
file_obs_list.append(
|
||||
FileObservation(
|
||||
where=[
|
||||
"network",
|
||||
"nodes",
|
||||
ref_map_nodes[node_ref],
|
||||
"folders",
|
||||
folder_obs_cfg["folder_name"],
|
||||
"files",
|
||||
file_obs_cfg["file_name"],
|
||||
]
|
||||
)
|
||||
)
|
||||
folder_obs_list.append(
|
||||
FolderObservation(
|
||||
where=[
|
||||
"network",
|
||||
"nodes",
|
||||
ref_map_nodes[node_ref],
|
||||
"folders",
|
||||
folder_obs_cfg["folder_name"],
|
||||
],
|
||||
files=file_obs_list,
|
||||
)
|
||||
)
|
||||
nic_obs_list = []
|
||||
for nic_uuid in net.nodes[ref_map_nodes[node_obs_cfg["node_ref"]]].nics.keys():
|
||||
nic_obs_list.append(
|
||||
NicObservation(where=["network", "nodes", ref_map_nodes[node_ref], "NICs", nic_uuid])
|
||||
)
|
||||
node_obs_list.append(
|
||||
NodeObservation(
|
||||
where=["network", "nodes", ref_map_nodes[node_ref]],
|
||||
services=service_obs_list,
|
||||
folders=folder_obs_list,
|
||||
nics=nic_obs_list,
|
||||
logon_status=False,
|
||||
)
|
||||
)
|
||||
for link_obs_cfg in observation_space_cfg["options"]["links"]:
|
||||
link_ref = link_obs_cfg["link_ref"]
|
||||
link_obs_list.append(LinkObservation(where=["network", "links", ref_map_links[link_ref]]))
|
||||
|
||||
acl_obs = AclObservation(
|
||||
node_ip_to_id=node_ip_to_index,
|
||||
ports=game_cfg["ports"],
|
||||
protocols=game_cfg["ports"],
|
||||
where=["network", "nodes", observation_space_cfg["options"]["acl"]["router_node_ref"]],
|
||||
)
|
||||
obs_space = UC2BlueObservation(
|
||||
nodes=node_obs_list, links=link_obs_list, acl=acl_obs, ics=ICSObservation()
|
||||
)
|
||||
elif observation_space_cfg["type"] == "UC2RedObservation":
|
||||
obs_space = UC2RedObservation.from_config(observation_space_cfg["options"], sim=sim)
|
||||
else:
|
||||
print("observation space config not specified correctly.")
|
||||
obs_space = NullObservation()
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
...
|
||||
elif agent_type == "GATERLAgent":
|
||||
...
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
...
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user