Fix formatting with precommit

This commit is contained in:
Marek Wolan
2023-10-09 18:35:30 +01:00
parent c9bc8fbf3d
commit 91f06c15f6
6 changed files with 189 additions and 160 deletions

View File

@@ -4,8 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gym import spaces
from primaite.simulator.sim_container import Simulation
from primaite import getLogger
from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
@@ -90,56 +91,56 @@ class NodeServiceAbstractAction(AbstractAction):
class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
class NodeFolderAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
@@ -153,25 +154,25 @@ class NodeFolderAbstractAction(AbstractAction):
class NodeFolderScanAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "restore"
@@ -293,7 +294,7 @@ class NetworkACLAddRuleAction(AbstractAction):
) -> List[str]:
if permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
permission_str = "ALLOW"
elif permission == 2:
@@ -302,30 +303,30 @@ class NetworkACLAddRuleAction(AbstractAction):
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(protocol_id-2)
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if source_ip_id in [0,1]:
if source_ip_id in [0, 1]:
src_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id-2)
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id-2)
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_ip_id in (0,1):
if dest_ip_id in (0, 1):
dst_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
# subtract 2 to account for UNUSED=0, and ALL=1
@@ -394,6 +395,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
# class NetworkNICDisableAction(AbstractAction):
# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
# super().__init__(manager=manager)
@@ -495,7 +497,7 @@ class ActionManager:
"num_protocols": len(self.protocols),
"num_ports": len(self.protocols),
"num_ips": len(self.ip_address_list),
"max_acl_rules":max_acl_rules,
"max_acl_rules": max_acl_rules,
"max_nics_per_node": max_nics_per_node,
}
self.actions: Dict[str, AbstractAction] = {}
@@ -507,8 +509,8 @@ class ActionManager:
# option_2: value2
# where `type` decides which AbstractAction subclass should be used
# and `options` is an optional dict of options to pass to the init method of the action class
act_type = act_spec.get('type')
act_options = act_spec.get('options', {})
act_type = act_spec.get("type")
act_options = act_spec.get("options", {})
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)
self.action_map: Dict[int, Tuple[str, Dict]] = {}
@@ -555,13 +557,12 @@ class ActionManager:
param_combinations = list(itertools.product(*possibilities))
all_action_possibilities.extend(
[
(
act_name, {param_names[i]:param_combinations[j][i] for i in range(len(param_names))}
) for j in range(len(param_combinations))]
)
return {i:p for i,p in enumerate(all_action_possibilities)}
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
for j in range(len(param_combinations))
]
)
return {i: p for i, p in enumerate(all_action_possibilities)}
def get_action(self, action: int) -> Tuple[str, Dict]:
"""Produce action in CAOS format"""
@@ -627,7 +628,7 @@ class ActionManager:
session=session,
actions=cfg["action_list"],
# node_uuids=cfg["options"]["node_uuids"],
**cfg['options'],
**cfg["options"],
protocols=session.options.protocols,
ports=session.options.ports,
ip_address_list=None,

View File

@@ -23,7 +23,7 @@ class AbstractAgent(ABC):
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
) -> None:
self.agent_name:str = agent_name or "unnamed_agent"
self.agent_name: str = agent_name or "unnamed_agent"
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
@@ -46,9 +46,9 @@ class AbstractAgent(ABC):
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {} )
return ("DO_NOTHING", {})
def format_request(self, action:Tuple[str,Dict], options:Dict[str, int]) -> List[str]:
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""

View File

@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Hashable, List, Optional, TYPE_CHECKING, Sequence, Tuple
from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, TYPE_CHECKING
from gym import spaces
from pydantic import BaseModel
from primaite.simulator.sim_container import Simulation
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
@@ -29,7 +30,7 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
:return: The value in the dictionary
:rtype: Any
"""
key_list = [*keys] # copy keys to a new list to prevent editing original list
key_list = [*keys] # copy keys to a new list to prevent editing original list
if len(key_list) == 0:
return dictionary
k = key_list.pop(0)
@@ -58,7 +59,7 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config:Dict, session:"PrimaiteSession"):
def from_config(cls, config: Dict, session: "PrimaiteSession"):
"""Create this observation space component form a serialised format.
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
@@ -98,7 +99,7 @@ class FileObservation(AbstractObservation):
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where=None):
return cls(where=parent_where+["files", config["file_name"]])
return cls(where=parent_where + ["files", config["file_name"]])
class ServiceObservation(AbstractObservation):
@@ -132,11 +133,8 @@ class ServiceObservation(AbstractObservation):
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]=None):
return cls(
where=parent_where+["services",session.ref_map_services[config['service_ref']].uuid]
)
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None):
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
class LinkObservation(AbstractObservation):
@@ -179,7 +177,7 @@ class LinkObservation(AbstractObservation):
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
return cls(where=['network','links', session.ref_map_links[config['link_ref']]])
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
@@ -237,13 +235,13 @@ class FolderObservation(AbstractObservation):
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]):
where = parent_where + ["folders", config['folder_name']]
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]):
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
return cls(where=where,files=files)
return cls(where=where, files=files)
class NicObservation(AbstractObservation):
@@ -267,7 +265,7 @@ class NicObservation(AbstractObservation):
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]):
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]):
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
@@ -278,7 +276,7 @@ class NodeObservation(AbstractObservation):
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
nics: List[NicObservation] = [],
logon_status:bool=False
logon_status: bool = False,
) -> None:
"""
Configurable observation for a node in the simulation.
@@ -306,7 +304,7 @@ class NodeObservation(AbstractObservation):
self.services: List[ServiceObservation] = services
self.folders: List[FolderObservation] = folders
self.nics: List[NicObservation] = nics
self.logon_status:bool=logon_status
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
@@ -315,7 +313,7 @@ class NodeObservation(AbstractObservation):
"operating_status": 0,
}
if self.logon_status:
self.default_observation['logon_status']=0
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
if self.where is None:
@@ -332,7 +330,7 @@ class NodeObservation(AbstractObservation):
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.logon_status:
obs['logon_status'] = 0
obs["logon_status"] = 0
return obs
@@ -345,26 +343,28 @@ class NodeObservation(AbstractObservation):
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
}
if self.logon_status:
space_shape['logon_status'] = spaces.Discrete(3)
space_shape["logon_status"] = spaces.Discrete(3)
return spaces.Dict(space_shape)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where:Optional[List[str]]= None) -> "NodeObservation":
node_uuid = session.ref_map_nodes[config['node_ref']]
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
) -> "NodeObservation":
node_uuid = session.ref_map_nodes[config["node_ref"]]
if parent_where is None:
where = ["network", "nodes", node_uuid]
else:
where = parent_where + ["nodes", node_uuid]
svc_configs = config.get('services', {})
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
folder_configs = config.get('folders', {})
folders = [FolderObservation.from_config(config=c,session=session, parent_where=where) for c in folder_configs]
folder_configs = config.get("folders", {})
folders = [FolderObservation.from_config(config=c, session=session, parent_where=where) for c in folder_configs]
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{'nic_uuid':n for n in nic_uuids }] if nic_uuids else []
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
logon_status = config.get('logon_status',False)
logon_status = config.get("logon_status", False)
return cls(where=where, services=services, folders=folders, nics=nics, logon_status=logon_status)
@@ -374,7 +374,12 @@ class AclObservation(AbstractObservation):
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self, node_ip_to_id: Dict[str,int], ports: List[int], protocols: list[str], where: Optional[Tuple[str]] = None, num_rules: int = 10
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: list[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
super().__init__()
self.where: Optional[Tuple[str]] = where
@@ -386,16 +391,18 @@ class AclObservation(AbstractObservation):
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
"RULES": {i+ 1:{
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
"RULES": {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
for i in range(self.num_rules)
}
}
@@ -406,8 +413,7 @@ class AclObservation(AbstractObservation):
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
#TODO: what if the ACL has more rules than num of max rules for obs space
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
obs["RULES"] = {}
for i, rule_state in acl_state.items():
@@ -439,7 +445,8 @@ class AclObservation(AbstractObservation):
{
"RULE": spaces.Dict(
{
i + 1: spaces.Dict(
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
@@ -460,23 +467,23 @@ class AclObservation(AbstractObservation):
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
node_ip_to_idx = {}
for node_idx, node_cfg in enumerate(config['node_order']):
for node_idx, node_cfg in enumerate(config["node_order"]):
n_ref = node_cfg["node_ref"]
n_obj = session.simulation.network.nodes[session.ref_map_nodes[n_ref]]
for nic_uuid, nic_obj in n_obj.nics.items():
node_ip_to_idx[nic_obj.ip_address] = node_idx + 2
router_uuid = session.ref_map_nodes[config['router_node_ref']]
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"])
where=["network", "nodes", router_uuid, "acl", "acl"],
)
class NullObservation(AbstractObservation):
def __init__(self, where:Optional[List[str]]=None):
def __init__(self, where: Optional[List[str]] = None):
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
@@ -487,20 +494,22 @@ class NullObservation(AbstractObservation):
return spaces.Dict({})
@classmethod
def from_config(cls, config:Dict, session:Optional["PrimaiteSession"]=None) -> "NullObservation":
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
return cls()
class ICSObservation(NullObservation): pass
class ICSObservation(NullObservation):
pass
class UC2BlueObservation(AbstractObservation):
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where:Optional[List[str]] = None,
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
super().__init__()
self.where: Optional[Tuple[str]] = where
@@ -510,36 +519,38 @@ class UC2BlueObservation(AbstractObservation):
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation : Dict = {
"NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)},
"LINKS": {i+1: l.default_observation for i,l in enumerate(self.links)},
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state:Dict) -> Dict:
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
obs = {}
obs['NODES'] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs['LINKS'] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs['ACL'] = self.acl.observe(state)
obs['ICS'] = self.ics.observe(state)
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict({
"NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i+1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
})
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config:Dict, session:"PrimaiteSession"):
def from_config(cls, config: Dict, session: "PrimaiteSession"):
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=n, session=session) for n in node_configs]
@@ -551,18 +562,18 @@ class UC2BlueObservation(AbstractObservation):
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, session=session)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=['network'])
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
def __init__(self, nodes:List[NodeObservation], where:Optional[List[str]] = None) -> None:
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where:Optional[List[str]] = where
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation : Dict = {
"NODES": {i+1: n.default_observation for i,n in enumerate(self.nodes)},
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
@@ -570,14 +581,16 @@ class UC2RedObservation(AbstractObservation):
return self.default_observation
obs = {}
obs['NODES'] = {i+1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict({
"NODES": spaces.Dict({i+1: node.space for i, node in enumerate(self.nodes)}),
})
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
@@ -586,7 +599,9 @@ class UC2RedObservation(AbstractObservation):
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation): pass
class UC2GreenObservation(NullObservation):
pass
class ObservationSpace:
"""
@@ -603,7 +618,7 @@ class ObservationSpace:
# what this class does:
# keep a list of observations
# create observations for an actor from the config
def __init__(self, observation:AbstractObservation) -> None:
def __init__(self, observation: AbstractObservation) -> None:
self.obs: AbstractObservation = observation
def observe(self, state) -> Dict:
@@ -614,12 +629,12 @@ class ObservationSpace:
return self.obs.space
@classmethod
def from_config(cls, config:Dict, session:"PrimaiteSession") -> "ObservationSpace":
if config['type'] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get('options',{}), session=session))
elif config['type'] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get('options',{}), session=session))
elif config['type'] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options",{}), session=session))
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
else:
raise ValueError("Observation space type invalid")

View File

@@ -1,34 +1,34 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List
class AbstractReward():
class AbstractReward:
def __init__(self):
...
@abstractmethod
def calculate(self, state:Dict) -> float:
def calculate(self, state: Dict) -> float:
return 0.3
class DummyReward(AbstractReward):
class DummyReward(AbstractReward):
def calculate(self, state: Dict) -> float:
return -0.1
class RewardFunction():
__rew_class_identifiers:Dict[str,type[AbstractReward]] = {
"DUMMY" : DummyReward
}
def __init__(self, reward_function:AbstractReward):
class RewardFunction:
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {"DUMMY": DummyReward}
def __init__(self, reward_function: AbstractReward):
self.reward: AbstractReward = reward_function
def calculate(self, state:Dict) -> float:
def calculate(self, state: Dict) -> float:
return self.reward.calculate(state)
@classmethod
def from_config(cls, cfg:Dict) -> "RewardFunction":
for rew_component_cfg in cfg['reward_components']:
rew_type = rew_component_cfg['type']
def from_config(cls, cfg: Dict) -> "RewardFunction":
for rew_component_cfg in cfg["reward_components"]:
rew_type = rew_component_cfg["type"]
rew_component = cls.__rew_class_identifiers[rew_type]()
new = cls(reward_function=rew_component)
return new

View File

@@ -10,6 +10,7 @@ from typing import Dict, List
from pydantic import BaseModel
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.observations import (
@@ -37,14 +38,12 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from primaite import getLogger
_LOGGER = getLogger(__name__)
@@ -91,7 +90,7 @@ class PrimaiteSession:
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
@@ -106,8 +105,8 @@ class PrimaiteSession:
def from_config(cls, cfg: dict) -> "PrimaiteSession":
sess = cls()
sess.options = PrimaiteSessionOptions(
ports = cfg['game_config']['ports'],
protocols = cfg['game_config']['protocols'],
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sim = sess.simulation
net = sim.network
@@ -230,7 +229,7 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space=ObservationSpace.from_config(observation_space_cfg, sess)
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
"""
# if observation_space_cfg is None:
@@ -331,23 +330,23 @@ class PrimaiteSession:
"""
# CREATE ACTION SPACE
action_space_cfg['options']['node_uuids'] = []
action_space_cfg["options"]["node_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get('options',{}).pop('nodes', {}):
if 'node_ref' in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option['node_ref']]
action_space_cfg['options']['node_uuids'].append(node_uuid)
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
# However, it's not possible to specify the node uuids directly in the config, as they are generated
# dynamically, so we have to translate node references to uuids before passing this config on.
if 'action_list' in action_space_cfg:
for action_config in action_space_cfg['action_list']:
if 'options' in action_config:
if 'target_router_ref' in action_config['options']:
_target = action_config['options']['target_router_ref']
action_config['options']['target_router_uuid'] = sess.ref_map_nodes[_target]
if "action_list" in action_space_cfg:
for action_config in action_space_cfg["action_list"]:
if "options" in action_config:
if "target_router_ref" in action_config["options"]:
_target = action_config["options"]["target_router_ref"]
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
@@ -357,16 +356,30 @@ class PrimaiteSession:
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
# TODO: implement non-random agents and fix this parsing
new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function)
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function)
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(agent_name=agent_cfg['ref'], action_space=action_space, observation_space=obs_space, reward_function=rew_function)
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
else:
print("agent type not found")
return sess

View File

@@ -57,8 +57,8 @@ class Switch(Node):
"""
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
state["num_ports"]= self.num_ports # redundant?
state["mac_address_table"]= {mac: port for mac, port in self.mac_address_table.items()}
state["num_ports"] = self.num_ports # redundant?
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
return state
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):