Fix bugged actions

This commit is contained in:
Marek Wolan
2023-10-09 17:29:50 +01:00
parent 081a3e519a
commit f68886d5df
5 changed files with 230 additions and 169 deletions

View File

@@ -63,8 +63,8 @@ game_config:
services:
- service_ref: data_manipulation_bot
observations:
- operating_status
- health_status
operating_status
health_status
folders: {}
action_space:
@@ -197,221 +197,221 @@ game_config:
1:
action: NODE_SERVICE_SCAN
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
- node_id: 2
- service_id: 1
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
- node_id: 3
- folder_id: 1
- file_id: 1
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
- node_id: 3
- folder_id: 1
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
- node_id: 3
- folder_id: 1
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
- node_id: 3
- folder_id: 1
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
- node_id: 3
- folder_id: 1
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
- node_id: 3
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
- node_id: 6
node_id: 6
20:
action: "NODE_STARTUP"
options:
- node_id: 6
node_id: 6
21:
action: "NODE_RESET"
options:
- node_id: 6
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 6
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 5
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 4
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 3
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 2
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
- position: 1
- permission: 2
- source_node_id: ...
- dest_node_id: ...
- source_port_id: ...
- dest_port_id: ...
- protocol_id: ...
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 0
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 1
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 2
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 3
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 4
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 5
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 6
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 7
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 8
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
- position: 9
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
- node_id: 6
- nic_index: 1
node_id: 6
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
- node_id: 6
- nic_index: 1
node_id: 6
nic_id: 1
options:
nodes:

View File

@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gym import spaces
from primaite.simulator.sim_container import Simulation
from primaite import getLogger
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
@@ -253,7 +255,7 @@ class NodeShutdownAction(NodeAbstractAction):
class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "start"
self.verb = "startup"
class NodeResetAction(NodeAbstractAction):
@@ -274,33 +276,73 @@ class NetworkACLAddRuleAction(AbstractAction):
**kwargs,
) -> None:
super().__init__(manager=manager)
num_permissions = 2
num_permissions = 3
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_idx": num_ips,
"dest_ip_idx": num_ips,
"source_port_idx": num_ports,
"dest_port_idx": num_ports,
"protocol_idx": num_protocols,
"source_ip_id": num_ips,
"dest_ip_id": num_ips,
"source_port_id": num_ports,
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
self.target_router_uuid: str = target_router_uuid
def form_request(
self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx
self, position, permission, source_ip_id, dest_ip_id, source_port_id, dest_port_id, protocol_id
) -> List[str]:
protocol = self.manager.get_internet_protocol_by_idx(protocol_idx)
src_ip = self.manager.get_ip_address_by_idx(source_ip_idx)
src_port = self.manager.get_port_by_idx(source_port_idx)
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx)
dst_port = self.manager.get_port_by_idx(dest_port_idx)
if permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
permission_str = "ALLOW"
elif permission == 2:
permission_str = "DENY"
else:
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(protocol_id-2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if source_ip_id in [0,1]:
src_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id-2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id-2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_ip_id in (0,1):
dst_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
"network",
"node",
self.target_router_uuid,
"acl",
"add_rule",
permission,
permission_str,
protocol,
src_ip,
src_port,
@@ -320,36 +362,52 @@ class NetworkACLRemoveRuleAction(AbstractAction):
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
class NetworkNICEnableAction(AbstractAction):
class NetworkNICAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str
def form_request(self, node_id: int, nic_id: int) -> List[str]:
node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id)
nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
if node_uuid is None or nic_uuid is None:
return ["do_nothing"]
return [
"network",
"node",
self.manager.get_node_uuid_by_idx(node_idx=node_id),
node_uuid,
"nic",
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
"enable",
nic_uuid,
self.verb,
]
class NetworkNICDisableAction(AbstractAction):
class NetworkNICEnableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "enable"
def form_request(self, node_id: int, nic_id: int) -> List[str]:
return [
"network",
"node",
self.manager.get_node_uuid_by_idx(node_idx=node_id),
"nic",
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
"disable",
]
class NetworkNICDisableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
# class NetworkNICDisableAction(AbstractAction):
# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
# super().__init__(manager=manager)
# self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
# def form_request(self, node_id: int, nic_id: int) -> List[str]:
# return [
# "network",
# "node",
# self.manager.get_node_uuid_by_idx(node_idx=node_id),
# "nic",
# self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
# "disable",
# ]
class ActionManager:

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING
from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING, Sequence, Tuple
from gym import spaces
from pydantic import BaseModel
@@ -15,7 +15,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
"""
def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any:
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
"""
Access an item from a deeply dictionary with a list of keys.
@@ -29,12 +29,13 @@ def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any:
:return: The value in the dictionary
:rtype: Any
"""
if len(keys) == 0:
key_list = [*keys] # copy keys to a new list to prevent editing original list
if len(key_list) == 0:
return dictionary
k = keys.pop(0)
k = key_list.pop(0)
if k not in dictionary:
return NOT_PRESENT_IN_STATE
return access_from_nested_dict(dictionary[k], keys)
return access_from_nested_dict(dictionary[k], key_list)
class AbstractObservation(ABC):
@@ -66,7 +67,7 @@ class AbstractObservation(ABC):
class FileObservation(AbstractObservation):
def __init__(self, where: Optional[List[str]] = None) -> None:
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
_summary_
@@ -79,7 +80,7 @@ class FileObservation(AbstractObservation):
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
@@ -104,7 +105,7 @@ class ServiceObservation(AbstractObservation):
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
def __init__(self, where: Optional[List[str]] = None) -> None:
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
@@ -115,7 +116,7 @@ class ServiceObservation(AbstractObservation):
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
if self.where is None:
@@ -124,7 +125,7 @@ class ServiceObservation(AbstractObservation):
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": service_state["operating_status"], "health_status": service_state["health_status"]}
return {"operating_status": service_state["operating_state"], "health_status": service_state["health_status"]}
@property
def space(self) -> spaces.Space:
@@ -132,7 +133,9 @@ class ServiceObservation(AbstractObservation):
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]=None):
return cls(where=parent_where+["services",session.ref_map_services[config['service_ref']]])
return cls(
where=parent_where+["services",session.ref_map_services[config['service_ref']].uuid]
)
@@ -140,7 +143,7 @@ class LinkObservation(AbstractObservation):
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[List[str]] = None) -> None:
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
@@ -151,7 +154,7 @@ class LinkObservation(AbstractObservation):
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
if self.where is None:
@@ -180,7 +183,7 @@ class LinkObservation(AbstractObservation):
class FolderObservation(AbstractObservation):
def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None:
def __init__(self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = []) -> None:
"""Initialise folder Observation, including files inside of the folder.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
@@ -199,7 +202,7 @@ class FolderObservation(AbstractObservation):
"""
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
self.files: List[FileObservation] = files
@@ -246,9 +249,9 @@ class FolderObservation(AbstractObservation):
class NicObservation(AbstractObservation):
default_observation: spaces.Space = {"nic_status": 0}
def __init__(self, where: Optional[List[str]] = None) -> None:
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
if self.where is None:
@@ -271,7 +274,7 @@ class NicObservation(AbstractObservation):
class NodeObservation(AbstractObservation):
def __init__(
self,
where: Optional[List[str]] = None,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
nics: List[NicObservation] = [],
@@ -298,7 +301,7 @@ class NodeObservation(AbstractObservation):
:type max_nics: int, optional
"""
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
self.services: List[ServiceObservation] = services
self.folders: List[FolderObservation] = folders
@@ -371,10 +374,10 @@ class AclObservation(AbstractObservation):
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[List[str]] = None, num_rules: int = 10
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[Tuple[str]] = None, num_rules: int = 10
) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
@@ -403,6 +406,8 @@ class AclObservation(AbstractObservation):
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
#TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
obs["RULES"] = {}
for i, rule_state in acl_state.items():
@@ -466,7 +471,7 @@ class AclObservation(AbstractObservation):
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
where=["network", "nodes", router_uuid])
where=["network", "nodes", router_uuid, "acl", "acl"])
@@ -498,7 +503,7 @@ class UC2BlueObservation(AbstractObservation):
where:Optional[List[str]] = None,
) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
@@ -517,11 +522,10 @@ class UC2BlueObservation(AbstractObservation):
return self.default_observation
obs = {}
obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs['ACL'] = {self.acl.observe(state)}
obs['ICS'] = {self.ics.observe(state)}
obs['ACL'] = self.acl.observe(state)
obs['ICS'] = self.ics.observe(state)
return obs
@@ -546,7 +550,7 @@ class UC2BlueObservation(AbstractObservation):
acl = AclObservation.from_config(config=acl_config, session=session)
ics_config = config["ics"]
ics = ICSObservation.from_config(ics_config)
ics = ICSObservation.from_config(config=ics_config, session=session)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network'])
return new

View File

@@ -111,11 +111,11 @@ class AccessControlList(SimComponent):
Action(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
IPProtocol[request[1]],
IPv4Address[request[2]],
Port[request[3]],
IPv4Address[request[4]],
Port[request[5]],
None if request[1] is "ALL" else IPProtocol[request[1]],
IPv4Address(request[2]),
None if request[3] is "ALL" else Port[request[3]],
IPv4Address(request[4]),
None if request[5] is "ALL" else Port[request[5]],
int(request[6]),
)
),

View File

@@ -55,12 +55,11 @@ class Switch(Node):
:return: Current state of this object and child objects.
"""
return {
"uuid": self.uuid,
"num_ports": self.num_ports, # redundant?
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
}
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
state["num_ports"]= self.num_ports # redundant?
state["mac_address_table"]= {mac: port for mac, port in self.mac_address_table.items()}
return state
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
"""