Align observations to Common approach

This commit is contained in:
Marek Wolan
2023-10-24 11:07:25 +01:00
parent d4eee36b7b
commit 6b7c483a67
2 changed files with 59 additions and 50 deletions

View File

@@ -2,10 +2,10 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 4
n_learn_steps: 128
n_eval_episodes: 1
n_eval_steps: 128
n_learn_episodes: 1
n_learn_steps: 8
n_eval_episodes: 0
n_eval_steps: 8
game_config:
@@ -39,10 +39,10 @@ game_config:
options:
nodes:
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
@@ -93,9 +93,9 @@ game_config:
options:
nodes:
- node_ref: client_1
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
@@ -113,9 +113,10 @@ game_config:
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 2
num_folders_per_node: 2
num_files_per_folder: 2
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
@@ -148,6 +149,8 @@ game_config:
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller

View File

@@ -167,7 +167,7 @@ class ServiceObservation(AbstractObservation):
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
@@ -206,7 +206,7 @@ class LinkObservation(AbstractObservation):
utilisation_category = int(utilisation_fraction * 10) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"protocols": {"all": {"load": utilisation_category}}}
return {"PROTOCOLS": {"ALL": utilisation_category}}
@property
def space(self) -> spaces.Space:
@@ -215,7 +215,7 @@ class LinkObservation(AbstractObservation):
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})})
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
@@ -264,7 +264,6 @@ class FolderObservation(AbstractObservation):
truncated_file = self.files.pop()
msg = f"Too many files in folde observation. Truncating file {truncated_file}"
_LOGGER.warn(msg)
raise UserWarning(msg)
self.default_observation = {
"health_status": 0,
@@ -407,6 +406,7 @@ class NodeObservation(AbstractObservation):
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> None:
"""
Configurable observation for a node in the simulation.
@@ -440,18 +440,25 @@ class NodeObservation(AbstractObservation):
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warn(msg)
raise UserWarning(msg)
# truncate service list
self.folders: List[FolderObservation] = folders
# add empty folder observation without `where` parameter that will always return default (blank) observations
while len(self.folders) < num_folders_per_node:
self.folders.append(FolderObservation())
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warn(msg)
self.nics: List[NicObservation] = nics
while len(self.nics) < num_nics_per_node:
self.nics.append(NicObservation())
while len(self.nics) > num_nics_per_node:
truncated_nic = self.nics.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warn(msg)
self.logon_status: bool = logon_status
self.default_observation: Dict = {
@@ -512,6 +519,7 @@ class NodeObservation(AbstractObservation):
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> "NodeObservation":
"""Create node observation from a config. Also creates child service, folder and NIC observations.
@@ -562,6 +570,7 @@ class NodeObservation(AbstractObservation):
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
@@ -605,19 +614,17 @@ class AclObservation(AbstractObservation):
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
"RULES": {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
@@ -636,10 +643,9 @@ class AclObservation(AbstractObservation):
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
obs["RULES"] = {}
for i, rule_state in acl_state.items():
if rule_state is None:
obs["RULES"][i + 1] = {
obs[i + 1] = {
"position": i,
"permission": 0,
"source_node_id": 0,
@@ -649,7 +655,7 @@ class AclObservation(AbstractObservation):
"protocol": 0,
}
else:
obs["RULES"][i + 1] = {
obs[i + 1] = {
"position": i,
"permission": rule_state["action"],
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
@@ -669,24 +675,20 @@ class AclObservation(AbstractObservation):
"""
return spaces.Dict(
{
"RULES": spaces.Dict(
i
+ 1: spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@@ -701,6 +703,7 @@ class AclObservation(AbstractObservation):
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_ref"]
@@ -715,6 +718,7 @@ class AclObservation(AbstractObservation):
ports=session.options.ports,
protocols=session.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
num_rules=max_acl_rules,
)
@@ -846,6 +850,7 @@ class UC2BlueObservation(AbstractObservation):
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
num_nics_per_node = config["num_nics_per_node"]
nodes = [
NodeObservation.from_config(
config=n,
@@ -853,6 +858,7 @@ class UC2BlueObservation(AbstractObservation):
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for n in node_configs
]