Fix bugged actions
This commit is contained in:
@@ -63,8 +63,8 @@ game_config:
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
- operating_status
|
||||
- health_status
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
@@ -197,221 +197,221 @@ game_config:
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
- node_id: 2
|
||||
- service_id: 1
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
- file_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
- node_id: 3
|
||||
- folder_id: 1
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
- node_id: 3
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
- node_id: 6
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
- node_id: 6
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
- node_id: 6
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 6
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 5
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 4
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 3
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 2
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
- position: 1
|
||||
- permission: 2
|
||||
- source_node_id: ...
|
||||
- dest_node_id: ...
|
||||
- source_port_id: ...
|
||||
- dest_port_id: ...
|
||||
- protocol_id: ...
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 0
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 1
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 2
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 3
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 4
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 5
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 6
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 7
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 8
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
- position: 9
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
- node_id: 6
|
||||
- nic_index: 1
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
- node_id: 6
|
||||
- nic_index: 1
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
|
||||
options:
|
||||
nodes:
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from gym import spaces
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
@@ -253,7 +255,7 @@ class NodeShutdownAction(NodeAbstractAction):
|
||||
class NodeStartupAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "start"
|
||||
self.verb = "startup"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction):
|
||||
@@ -274,33 +276,73 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 2
|
||||
num_permissions = 3
|
||||
self.shape: Dict[str, int] = {
|
||||
"position": max_acl_rules,
|
||||
"permission": num_permissions,
|
||||
"source_ip_idx": num_ips,
|
||||
"dest_ip_idx": num_ips,
|
||||
"source_port_idx": num_ports,
|
||||
"dest_port_idx": num_ports,
|
||||
"protocol_idx": num_protocols,
|
||||
"source_ip_id": num_ips,
|
||||
"dest_ip_id": num_ips,
|
||||
"source_port_id": num_ports,
|
||||
"dest_port_id": num_ports,
|
||||
"protocol_id": num_protocols,
|
||||
}
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(
|
||||
self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx
|
||||
self, position, permission, source_ip_id, dest_ip_id, source_port_id, dest_port_id, protocol_id
|
||||
) -> List[str]:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_idx)
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_idx)
|
||||
src_port = self.manager.get_port_by_idx(source_port_idx)
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx)
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_idx)
|
||||
if permission == 0:
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
elif permission == 1:
|
||||
permission_str = "ALLOW"
|
||||
elif permission == 2:
|
||||
permission_str = "DENY"
|
||||
else:
|
||||
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
|
||||
if protocol_id == 0:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
|
||||
if protocol_id == 1:
|
||||
protocol = "ALL"
|
||||
else:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id-2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if source_ip_id in [0,1]:
|
||||
src_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id-2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(source_port_id-2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_ip_id in (0,1):
|
||||
dst_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.target_router_uuid,
|
||||
"acl",
|
||||
"add_rule",
|
||||
permission,
|
||||
permission_str,
|
||||
protocol,
|
||||
src_ip,
|
||||
src_port,
|
||||
@@ -320,36 +362,52 @@ class NetworkACLRemoveRuleAction(AbstractAction):
|
||||
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
|
||||
|
||||
|
||||
class NetworkNICEnableAction(AbstractAction):
|
||||
class NetworkNICAbstractAction(AbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id)
|
||||
nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
|
||||
if node_uuid is None or nic_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.manager.get_node_uuid_by_idx(node_idx=node_id),
|
||||
node_uuid,
|
||||
"nic",
|
||||
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
|
||||
"enable",
|
||||
nic_uuid,
|
||||
self.verb,
|
||||
]
|
||||
|
||||
|
||||
class NetworkNICDisableAction(AbstractAction):
|
||||
class NetworkNICEnableAction(NetworkNICAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "enable"
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
self.manager.get_node_uuid_by_idx(node_idx=node_id),
|
||||
"nic",
|
||||
self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
|
||||
"disable",
|
||||
]
|
||||
|
||||
class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "disable"
|
||||
|
||||
# class NetworkNICDisableAction(AbstractAction):
|
||||
# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
# super().__init__(manager=manager)
|
||||
# self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
|
||||
# def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
# return [
|
||||
# "network",
|
||||
# "node",
|
||||
# self.manager.get_node_uuid_by_idx(node_idx=node_id),
|
||||
# "nic",
|
||||
# self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
|
||||
# "disable",
|
||||
# ]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING
|
||||
from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING, Sequence, Tuple
|
||||
|
||||
from gym import spaces
|
||||
from pydantic import BaseModel
|
||||
@@ -15,7 +15,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
|
||||
"""
|
||||
|
||||
|
||||
def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any:
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
"""
|
||||
Access an item from a deeply dictionary with a list of keys.
|
||||
|
||||
@@ -29,12 +29,13 @@ def access_from_nested_dict(dictionary: Dict, keys: List[Hashable]) -> Any:
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
if len(keys) == 0:
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
if len(key_list) == 0:
|
||||
return dictionary
|
||||
k = keys.pop(0)
|
||||
k = key_list.pop(0)
|
||||
if k not in dictionary:
|
||||
return NOT_PRESENT_IN_STATE
|
||||
return access_from_nested_dict(dictionary[k], keys)
|
||||
return access_from_nested_dict(dictionary[k], key_list)
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
@@ -66,7 +67,7 @@ class AbstractObservation(ABC):
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""
|
||||
_summary_
|
||||
|
||||
@@ -79,7 +80,7 @@ class FileObservation(AbstractObservation):
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
|
||||
@@ -104,7 +105,7 @@ class ServiceObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
@@ -115,7 +116,7 @@ class ServiceObservation(AbstractObservation):
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if self.where is None:
|
||||
@@ -124,7 +125,7 @@ class ServiceObservation(AbstractObservation):
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": service_state["operating_status"], "health_status": service_state["health_status"]}
|
||||
return {"operating_status": service_state["operating_state"], "health_status": service_state["health_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -132,7 +133,9 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]=None):
|
||||
return cls(where=parent_where+["services",session.ref_map_services[config['service_ref']]])
|
||||
return cls(
|
||||
where=parent_where+["services",session.ref_map_services[config['service_ref']].uuid]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -140,7 +143,7 @@ class LinkObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
@@ -151,7 +154,7 @@ class LinkObservation(AbstractObservation):
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if self.where is None:
|
||||
@@ -180,7 +183,7 @@ class LinkObservation(AbstractObservation):
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
def __init__(self, where: Optional[List[str]] = None, files: List[FileObservation] = []) -> None:
|
||||
def __init__(self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = []) -> None:
|
||||
"""Initialise folder Observation, including files inside of the folder.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
@@ -199,7 +202,7 @@ class FolderObservation(AbstractObservation):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
|
||||
@@ -246,9 +249,9 @@ class FolderObservation(AbstractObservation):
|
||||
class NicObservation(AbstractObservation):
|
||||
default_observation: spaces.Space = {"nic_status": 0}
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None) -> None:
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if self.where is None:
|
||||
@@ -271,7 +274,7 @@ class NicObservation(AbstractObservation):
|
||||
class NodeObservation(AbstractObservation):
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[List[str]] = None,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
nics: List[NicObservation] = [],
|
||||
@@ -298,7 +301,7 @@ class NodeObservation(AbstractObservation):
|
||||
:type max_nics: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
self.folders: List[FolderObservation] = folders
|
||||
@@ -371,10 +374,10 @@ class AclObservation(AbstractObservation):
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[List[str]] = None, num_rules: int = 10
|
||||
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[Tuple[str]] = None, num_rules: int = 10
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
@@ -403,6 +406,8 @@ class AclObservation(AbstractObservation):
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
|
||||
#TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
obs["RULES"] = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
@@ -466,7 +471,7 @@ class AclObservation(AbstractObservation):
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=session.options.ports,
|
||||
protocols=session.options.protocols,
|
||||
where=["network", "nodes", router_uuid])
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"])
|
||||
|
||||
|
||||
|
||||
@@ -498,7 +503,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
where:Optional[List[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
@@ -517,11 +522,10 @@ class UC2BlueObservation(AbstractObservation):
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
|
||||
obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs['ACL'] = {self.acl.observe(state)}
|
||||
obs['ICS'] = {self.ics.observe(state)}
|
||||
obs['ACL'] = self.acl.observe(state)
|
||||
obs['ICS'] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@@ -546,7 +550,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
acl = AclObservation.from_config(config=acl_config, session=session)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(ics_config)
|
||||
ics = ICSObservation.from_config(config=ics_config, session=session)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network'])
|
||||
return new
|
||||
|
||||
|
||||
@@ -111,11 +111,11 @@ class AccessControlList(SimComponent):
|
||||
Action(
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
IPProtocol[request[1]],
|
||||
IPv4Address[request[2]],
|
||||
Port[request[3]],
|
||||
IPv4Address[request[4]],
|
||||
Port[request[5]],
|
||||
None if request[1] is "ALL" else IPProtocol[request[1]],
|
||||
IPv4Address(request[2]),
|
||||
None if request[3] is "ALL" else Port[request[3]],
|
||||
IPv4Address(request[4]),
|
||||
None if request[5] is "ALL" else Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
),
|
||||
|
||||
@@ -55,12 +55,11 @@ class Switch(Node):
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
"""
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
"num_ports": self.num_ports, # redundant?
|
||||
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
|
||||
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
|
||||
}
|
||||
state = super().describe_state()
|
||||
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
|
||||
state["num_ports"]= self.num_ports # redundant?
|
||||
state["mac_address_table"]= {mac: port for mac, port in self.mac_address_table.items()}
|
||||
return state
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user