#2417 Add categorisation and updated new configs from merge

This commit is contained in:
Marek Wolan
2024-04-01 22:03:28 +01:00
parent e4300faa1c
commit d2c7ae481c
9 changed files with 188 additions and 112 deletions

View File

@@ -43,6 +43,26 @@ class FileObservation(AbstractObservation, identifier="FILE"):
if self.include_num_access:
self.default_observation["num_access"] = 0
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_access(self, num_access: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
return 3
elif num_access > self.med_threshold:
return 2
elif num_access > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -57,8 +77,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
return self.default_observation
obs = {"health_status": file_state["visible_status"]}
if self.include_num_access:
obs["num_access"] = file_state["num_access"]
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@property

View File

@@ -214,9 +214,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Constructed firewall observation instance.
:rtype: FirewallObservation
"""
where = parent_where + ["nodes", config.hostname]
return cls(
where=where,
where=parent_where + ["nodes", config.hostname],
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,

View File

@@ -185,9 +185,11 @@ class ObservationManager:
Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
If None, a blank observation space is created.
Otherwise, this must be a Dict with a type field and options field.
type: string that corresponds to one of the observation identifiers that are provided when subclassing
AbstractObservation
options: this must adhere to the chosen observation type's ConfigSchema nested class.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame

View File

@@ -98,6 +98,26 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
return 3
elif num_executions > self.med_threshold:
return 2
elif num_executions > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -113,7 +133,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
return {
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"num_executions": application_state["num_executions"],
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
}
@property