Align observations to Common approach
This commit is contained in:
@@ -2,10 +2,10 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 4
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 1
|
||||
n_eval_steps: 128
|
||||
n_learn_episodes: 1
|
||||
n_learn_steps: 8
|
||||
n_eval_episodes: 0
|
||||
n_eval_steps: 8
|
||||
|
||||
|
||||
game_config:
|
||||
@@ -39,10 +39,10 @@ game_config:
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
@@ -93,9 +93,9 @@ game_config:
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -113,9 +113,10 @@ game_config:
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 2
|
||||
num_folders_per_node: 2
|
||||
num_files_per_folder: 2
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
@@ -148,6 +149,8 @@ game_config:
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
|
||||
@@ -167,7 +167,7 @@ class ServiceObservation(AbstractObservation):
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
@@ -206,7 +206,7 @@ class LinkObservation(AbstractObservation):
|
||||
utilisation_category = int(utilisation_fraction * 10) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"protocols": {"all": {"load": utilisation_category}}}
|
||||
return {"PROTOCOLS": {"ALL": utilisation_category}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -215,7 +215,7 @@ class LinkObservation(AbstractObservation):
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})})
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
|
||||
@@ -264,7 +264,6 @@ class FolderObservation(AbstractObservation):
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folde observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warn(msg)
|
||||
raise UserWarning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
@@ -407,6 +406,7 @@ class NodeObservation(AbstractObservation):
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
@@ -440,18 +440,25 @@ class NodeObservation(AbstractObservation):
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warn(msg)
|
||||
raise UserWarning(msg)
|
||||
# truncate service list
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
# add empty folder observation without `where` parameter that will always return default (blank) observations
|
||||
while len(self.folders) < num_folders_per_node:
|
||||
self.folders.append(FolderObservation())
|
||||
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warn(msg)
|
||||
|
||||
self.nics: List[NicObservation] = nics
|
||||
while len(self.nics) < num_nics_per_node:
|
||||
self.nics.append(NicObservation())
|
||||
while len(self.nics) > num_nics_per_node:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warn(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
@@ -512,6 +519,7 @@ class NodeObservation(AbstractObservation):
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> "NodeObservation":
|
||||
"""Create node observation from a config. Also creates child service, folder and NIC observations.
|
||||
|
||||
@@ -562,6 +570,7 @@ class NodeObservation(AbstractObservation):
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
|
||||
|
||||
@@ -605,19 +614,17 @@ class AclObservation(AbstractObservation):
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
"RULES": {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
@@ -636,10 +643,9 @@ class AclObservation(AbstractObservation):
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
obs["RULES"] = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
if rule_state is None:
|
||||
obs["RULES"][i + 1] = {
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
@@ -649,7 +655,7 @@ class AclObservation(AbstractObservation):
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs["RULES"][i + 1] = {
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
@@ -669,24 +675,20 @@ class AclObservation(AbstractObservation):
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"RULES": spaces.Dict(
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -701,6 +703,7 @@ class AclObservation(AbstractObservation):
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_ref"]
|
||||
@@ -715,6 +718,7 @@ class AclObservation(AbstractObservation):
|
||||
ports=session.options.ports,
|
||||
protocols=session.options.protocols,
|
||||
where=["network", "nodes", router_uuid, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
@@ -846,6 +850,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
num_nics_per_node = config["num_nics_per_node"]
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
@@ -853,6 +858,7 @@ class UC2BlueObservation(AbstractObservation):
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for n in node_configs
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user