Fix formatting with precommit
This commit is contained in:
@@ -4,8 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -90,56 +91,56 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
|
||||
self.verb: str
|
||||
@@ -153,25 +154,25 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "restore"
|
||||
|
||||
@@ -293,7 +294,7 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
) -> List[str]:
|
||||
if permission == 0:
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
elif permission == 1:
|
||||
permission_str = "ALLOW"
|
||||
elif permission == 2:
|
||||
@@ -302,30 +303,30 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
|
||||
if protocol_id == 0:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
|
||||
if protocol_id == 1:
|
||||
protocol = "ALL"
|
||||
else:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id-2)
|
||||
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if source_ip_id in [0,1]:
|
||||
if source_ip_id in [0, 1]:
|
||||
src_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id-2)
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(source_port_id-2)
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if dest_ip_id in (0,1):
|
||||
if dest_ip_id in (0, 1):
|
||||
dst_ip = "ALL"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
@@ -394,6 +395,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
# class NetworkNICDisableAction(AbstractAction):
|
||||
# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
# super().__init__(manager=manager)
|
||||
@@ -495,7 +497,7 @@ class ActionManager:
|
||||
"num_protocols": len(self.protocols),
|
||||
"num_ports": len(self.protocols),
|
||||
"num_ips": len(self.ip_address_list),
|
||||
"max_acl_rules":max_acl_rules,
|
||||
"max_acl_rules": max_acl_rules,
|
||||
"max_nics_per_node": max_nics_per_node,
|
||||
}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
@@ -507,8 +509,8 @@ class ActionManager:
|
||||
# option_2: value2
|
||||
# where `type` decides which AbstractAction subclass should be used
|
||||
# and `options` is an optional dict of options to pass to the init method of the action class
|
||||
act_type = act_spec.get('type')
|
||||
act_options = act_spec.get('options', {})
|
||||
act_type = act_spec.get("type")
|
||||
act_options = act_spec.get("options", {})
|
||||
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
@@ -555,13 +557,12 @@ class ActionManager:
|
||||
param_combinations = list(itertools.product(*possibilities))
|
||||
all_action_possibilities.extend(
|
||||
[
|
||||
(
|
||||
act_name, {param_names[i]:param_combinations[j][i] for i in range(len(param_names))}
|
||||
) for j in range(len(param_combinations))]
|
||||
)
|
||||
|
||||
return {i:p for i,p in enumerate(all_action_possibilities)}
|
||||
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
|
||||
for j in range(len(param_combinations))
|
||||
]
|
||||
)
|
||||
|
||||
return {i: p for i, p in enumerate(all_action_possibilities)}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format"""
|
||||
@@ -627,7 +628,7 @@ class ActionManager:
|
||||
session=session,
|
||||
actions=cfg["action_list"],
|
||||
# node_uuids=cfg["options"]["node_uuids"],
|
||||
**cfg['options'],
|
||||
**cfg["options"],
|
||||
protocols=session.options.protocols,
|
||||
ports=session.options.ports,
|
||||
ip_address_list=None,
|
||||
|
||||
@@ -23,7 +23,7 @@ class AbstractAgent(ABC):
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
self.agent_name:str = agent_name or "unnamed_agent"
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
@@ -46,9 +46,9 @@ class AbstractAgent(ABC):
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {} )
|
||||
return ("DO_NOTHING", {})
|
||||
|
||||
def format_request(self, action:Tuple[str,Dict], options:Dict[str, int]) -> List[str]:
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING, Sequence, Tuple
|
||||
from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, TYPE_CHECKING
|
||||
|
||||
from gym import spaces
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
@@ -29,7 +30,7 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
if len(key_list) == 0:
|
||||
return dictionary
|
||||
k = key_list.pop(0)
|
||||
@@ -58,7 +59,7 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config:Dict, session:"PrimaiteSession"):
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
|
||||
@@ -98,7 +99,7 @@ class FileObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where=None):
|
||||
return cls(where=parent_where+["files", config["file_name"]])
|
||||
return cls(where=parent_where + ["files", config["file_name"]])
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
@@ -132,11 +133,8 @@ class ServiceObservation(AbstractObservation):
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]=None):
|
||||
return cls(
|
||||
where=parent_where+["services",session.ref_map_services[config['service_ref']].uuid]
|
||||
)
|
||||
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None):
|
||||
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
@@ -179,7 +177,7 @@ class LinkObservation(AbstractObservation):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
return cls(where=['network','links', session.ref_map_links[config['link_ref']]])
|
||||
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
@@ -237,13 +235,13 @@ class FolderObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]):
|
||||
where = parent_where + ["folders", config['folder_name']]
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]):
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
|
||||
|
||||
return cls(where=where,files=files)
|
||||
return cls(where=where, files=files)
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
@@ -267,7 +265,7 @@ class NicObservation(AbstractObservation):
|
||||
return spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]):
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]):
|
||||
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
|
||||
|
||||
|
||||
@@ -278,7 +276,7 @@ class NodeObservation(AbstractObservation):
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
nics: List[NicObservation] = [],
|
||||
logon_status:bool=False
|
||||
logon_status: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
@@ -306,7 +304,7 @@ class NodeObservation(AbstractObservation):
|
||||
self.services: List[ServiceObservation] = services
|
||||
self.folders: List[FolderObservation] = folders
|
||||
self.nics: List[NicObservation] = nics
|
||||
self.logon_status:bool=logon_status
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
@@ -315,7 +313,7 @@ class NodeObservation(AbstractObservation):
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation['logon_status']=0
|
||||
self.default_observation["logon_status"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if self.where is None:
|
||||
@@ -332,7 +330,7 @@ class NodeObservation(AbstractObservation):
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
|
||||
if self.logon_status:
|
||||
obs['logon_status'] = 0
|
||||
obs["logon_status"] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@@ -345,26 +343,28 @@ class NodeObservation(AbstractObservation):
|
||||
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape['logon_status'] = spaces.Discrete(3)
|
||||
space_shape["logon_status"] = spaces.Discrete(3)
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]= None) -> "NodeObservation":
|
||||
node_uuid = session.ref_map_nodes[config['node_ref']]
|
||||
def from_config(
|
||||
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
|
||||
) -> "NodeObservation":
|
||||
node_uuid = session.ref_map_nodes[config["node_ref"]]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_uuid]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_uuid]
|
||||
|
||||
svc_configs = config.get('services', {})
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get('folders', {})
|
||||
folders = [FolderObservation.from_config(config=c,session=session, parent_where=where) for c in folder_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [FolderObservation.from_config(config=c, session=session, parent_where=where) for c in folder_configs]
|
||||
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
|
||||
nic_configs = [{'nic_uuid':n for n in nic_uuids }] if nic_uuids else []
|
||||
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
|
||||
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get('logon_status',False)
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(where=where, services=services, folders=folders, nics=nics, logon_status=logon_status)
|
||||
|
||||
|
||||
@@ -374,7 +374,12 @@ class AclObservation(AbstractObservation):
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[Tuple[str]] = None, num_rules: int = 10
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: list[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
@@ -386,16 +391,18 @@ class AclObservation(AbstractObservation):
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
"RULES": {i+ 1:{
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
"RULES": {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -406,8 +413,7 @@ class AclObservation(AbstractObservation):
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
|
||||
#TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
obs["RULES"] = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
@@ -439,7 +445,8 @@ class AclObservation(AbstractObservation):
|
||||
{
|
||||
"RULE": spaces.Dict(
|
||||
{
|
||||
i + 1: spaces.Dict(
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
@@ -460,23 +467,23 @@ class AclObservation(AbstractObservation):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
|
||||
node_ip_to_idx = {}
|
||||
for node_idx, node_cfg in enumerate(config['node_order']):
|
||||
for node_idx, node_cfg in enumerate(config["node_order"]):
|
||||
n_ref = node_cfg["node_ref"]
|
||||
n_obj = session.simulation.network.nodes[session.ref_map_nodes[n_ref]]
|
||||
for nic_uuid, nic_obj in n_obj.nics.items():
|
||||
node_ip_to_idx[nic_obj.ip_address] = node_idx + 2
|
||||
|
||||
router_uuid = session.ref_map_nodes[config['router_node_ref']]
|
||||
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=session.options.ports,
|
||||
protocols=session.options.protocols,
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"])
|
||||
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"],
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
def __init__(self, where:Optional[List[str]]=None):
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
@@ -487,20 +494,22 @@ class NullObservation(AbstractObservation):
|
||||
return spaces.Dict({})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config:Dict, session:Optional["PrimaiteSession"]=None) -> "NullObservation":
|
||||
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
|
||||
return cls()
|
||||
|
||||
class ICSObservation(NullObservation): pass
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
pass
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where:Optional[List[str]] = None,
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
@@ -510,36 +519,38 @@ class UC2BlueObservation(AbstractObservation):
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation : Dict = {
|
||||
"NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)},
|
||||
"LINKS": {i+1: l.default_observation for i,l in enumerate(self.links)},
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state:Dict) -> Dict:
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs['ACL'] = self.acl.observe(state)
|
||||
obs['ICS'] = self.ics.observe(state)
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
obs["ICS"] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({
|
||||
"NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i+1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
})
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config:Dict, session:"PrimaiteSession"):
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=n, session=session) for n in node_configs]
|
||||
|
||||
@@ -551,18 +562,18 @@ class UC2BlueObservation(AbstractObservation):
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, session=session)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network'])
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
def __init__(self, nodes:List[NodeObservation], where:Optional[List[str]] = None) -> None:
|
||||
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where:Optional[List[str]] = where
|
||||
self.where: Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation : Dict = {
|
||||
"NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)},
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
@@ -570,14 +581,16 @@ class UC2RedObservation(AbstractObservation):
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs['NODES'] = {i+1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({
|
||||
"NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}),
|
||||
})
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession"):
|
||||
@@ -586,7 +599,9 @@ class UC2RedObservation(AbstractObservation):
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation): pass
|
||||
class UC2GreenObservation(NullObservation):
|
||||
pass
|
||||
|
||||
|
||||
class ObservationSpace:
|
||||
"""
|
||||
@@ -603,7 +618,7 @@ class ObservationSpace:
|
||||
# what this class does:
|
||||
# keep a list of observations
|
||||
# create observations for an actor from the config
|
||||
def __init__(self, observation:AbstractObservation) -> None:
|
||||
def __init__(self, observation: AbstractObservation) -> None:
|
||||
self.obs: AbstractObservation = observation
|
||||
|
||||
def observe(self, state) -> Dict:
|
||||
@@ -614,12 +629,12 @@ class ObservationSpace:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config:Dict, session:"PrimaiteSession") -> "ObservationSpace":
|
||||
if config['type'] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get('options',{}), session=session))
|
||||
elif config['type'] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get('options',{}), session=session))
|
||||
elif config['type'] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options",{}), session=session))
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
class AbstractReward():
|
||||
|
||||
class AbstractReward:
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state:Dict) -> float:
|
||||
def calculate(self, state: Dict) -> float:
|
||||
return 0.3
|
||||
|
||||
class DummyReward(AbstractReward):
|
||||
|
||||
class DummyReward(AbstractReward):
|
||||
def calculate(self, state: Dict) -> float:
|
||||
return -0.1
|
||||
|
||||
class RewardFunction():
|
||||
__rew_class_identifiers:Dict[str,type[AbstractReward]] = {
|
||||
"DUMMY" : DummyReward
|
||||
}
|
||||
def __init__(self, reward_function:AbstractReward):
|
||||
|
||||
class RewardFunction:
|
||||
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {"DUMMY": DummyReward}
|
||||
|
||||
def __init__(self, reward_function: AbstractReward):
|
||||
self.reward: AbstractReward = reward_function
|
||||
|
||||
def calculate(self, state:Dict) -> float:
|
||||
def calculate(self, state: Dict) -> float:
|
||||
return self.reward.calculate(state)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg:Dict) -> "RewardFunction":
|
||||
for rew_component_cfg in cfg['reward_components']:
|
||||
rew_type = rew_component_cfg['type']
|
||||
def from_config(cls, cfg: Dict) -> "RewardFunction":
|
||||
for rew_component_cfg in cfg["reward_components"]:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
rew_component = cls.__rew_class_identifiers[rew_type]()
|
||||
new = cls(reward_function=rew_component)
|
||||
return new
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import (
|
||||
@@ -37,14 +38,12 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -91,7 +90,7 @@ class PrimaiteSession:
|
||||
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
|
||||
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
|
||||
agent_request = agent.format_request(agent_action, action_options)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
@@ -106,8 +105,8 @@ class PrimaiteSession:
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
sess = cls()
|
||||
sess.options = PrimaiteSessionOptions(
|
||||
ports = cfg['game_config']['ports'],
|
||||
protocols = cfg['game_config']['protocols'],
|
||||
ports=cfg["game_config"]["ports"],
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
@@ -230,7 +229,7 @@ class PrimaiteSession:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space=ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
|
||||
"""
|
||||
# if observation_space_cfg is None:
|
||||
@@ -331,23 +330,23 @@ class PrimaiteSession:
|
||||
"""
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg['options']['node_uuids'] = []
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
# if a list of nodes is defined, convert them from node references to node UUIDs
|
||||
for action_node_option in action_space_cfg.get('options',{}).pop('nodes', {}):
|
||||
if 'node_ref' in action_node_option:
|
||||
node_uuid = sess.ref_map_nodes[action_node_option['node_ref']]
|
||||
action_space_cfg['options']['node_uuids'].append(node_uuid)
|
||||
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
|
||||
if "node_ref" in action_node_option:
|
||||
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
|
||||
action_space_cfg["options"]["node_uuids"].append(node_uuid)
|
||||
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
|
||||
# we will pass node_uuids as a part of the action space config.
|
||||
# However, it's not possible to specify the node uuids directly in the config, as they are generated
|
||||
# dynamically, so we have to translate node references to uuids before passing this config on.
|
||||
|
||||
if 'action_list' in action_space_cfg:
|
||||
for action_config in action_space_cfg['action_list']:
|
||||
if 'options' in action_config:
|
||||
if 'target_router_ref' in action_config['options']:
|
||||
_target = action_config['options']['target_router_ref']
|
||||
action_config['options']['target_router_uuid'] = sess.ref_map_nodes[_target]
|
||||
if "action_list" in action_space_cfg:
|
||||
for action_config in action_space_cfg["action_list"]:
|
||||
if "options" in action_config:
|
||||
if "target_router_ref" in action_config["options"]:
|
||||
_target = action_config["options"]["target_router_ref"]
|
||||
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
|
||||
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
|
||||
@@ -357,16 +356,30 @@ class PrimaiteSession:
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function)
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function)
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function)
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
|
||||
return sess
|
||||
|
||||
@@ -57,8 +57,8 @@ class Switch(Node):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
|
||||
state["num_ports"]= self.num_ports # redundant?
|
||||
state["mac_address_table"]= {mac: port for mac, port in self.mac_address_table.items()}
|
||||
state["num_ports"] = self.num_ports # redundant?
|
||||
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
|
||||
return state
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
|
||||
Reference in New Issue
Block a user