Configure observation type MULTIDISCRETE
This commit is contained in:
@@ -83,6 +83,13 @@ class ActionType(Enum):
|
||||
ACL = 1
|
||||
|
||||
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
BOX = 0
|
||||
MULTIDISCRETE = 1
|
||||
|
||||
|
||||
class FileSystemState(Enum):
|
||||
"""File System State."""
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
- itemType: ACTIONS
|
||||
type: NODE
|
||||
- itemType: OBSERVATIONS
|
||||
type: MULTIDISCRETE
|
||||
- itemType: STEPS
|
||||
steps: 128
|
||||
- itemType: PORTS
|
||||
|
||||
@@ -23,6 +23,7 @@ from primaite.common.enums import (
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
@@ -148,6 +149,9 @@ class Primaite(Env):
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
|
||||
# Observation type.
|
||||
self.observation_type = 0
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
try:
|
||||
self.config_file = open(self.config_values.config_filename_use_case, "r")
|
||||
@@ -203,26 +207,8 @@ class Primaite(Env):
|
||||
# - service E state | service E loading
|
||||
# - service F state | service F loading
|
||||
# - service G state | service G loading
|
||||
|
||||
# Calculate the number of items that need to be included in the
|
||||
# observation space
|
||||
num_items = self.num_links + self.num_nodes
|
||||
# Set the number of observation parameters, being # of services plus id,
|
||||
# hardware state, file system state and SoftwareState (i.e. 4)
|
||||
self.num_observation_parameters = (
|
||||
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
)
|
||||
# Define the observation shape
|
||||
self.observation_shape = (num_items, self.num_observation_parameters)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=self.config_values.observation_space_high_value,
|
||||
shape=self.observation_shape,
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
# This is the observation that is sent back via the rest and step functions
|
||||
self.env_obs = np.zeros(self.observation_shape, dtype=np.int64)
|
||||
# Initiate observation space
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
if self.action_type == ActionType.NODE:
|
||||
@@ -671,49 +657,172 @@ class Primaite(Env):
|
||||
else:
|
||||
pass
|
||||
|
||||
def init_observations(self):
|
||||
"""Build the observation space based on network laydown and provide initial obs.
|
||||
|
||||
This method uses the object's `num_links`, `num_nodes`, `num_services`,
|
||||
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
|
||||
attributes to figure out the correct shape and format for the observation space.
|
||||
|
||||
Returns
|
||||
-------
|
||||
gym.spaces.Space
|
||||
Gym observation space
|
||||
numpy.Array
|
||||
Initial observation with all entries set to 0
|
||||
"""
|
||||
if self.observation_type == ObservationType.BOX:
|
||||
_LOGGER.info("Observation space type BOX selected")
|
||||
|
||||
# 1. Determine observation shape from laydown
|
||||
num_items = self.num_links + self.num_nodes
|
||||
num_observation_parameters = (
|
||||
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
)
|
||||
observation_shape = (num_items, num_observation_parameters)
|
||||
|
||||
# 2. Create observation space & zeroed out sample from space.
|
||||
observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=self.OBSERVATION_SPACE_HIGH_VALUE,
|
||||
shape=observation_shape,
|
||||
dtype=np.int64,
|
||||
)
|
||||
initial_observation = np.zeros(observation_shape, dtype=np.int64)
|
||||
|
||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
||||
_LOGGER.info("Observation space MULTIDISCRETE selected")
|
||||
|
||||
# 1. Determine observation shape from laydown
|
||||
node_obs_shape = [
|
||||
len(HardwareState) + 1,
|
||||
len(SoftwareState) + 1,
|
||||
len(FileSystemState) + 1,
|
||||
]
|
||||
node_services = [len(SoftwareState) + 1] * self.num_services
|
||||
node_obs_shape = node_obs_shape + node_services
|
||||
# the magic number 5 refers to 5 states of quantisation of traffic amount.
|
||||
# (zero, low, medium, high, fully utilised/overwhelmed)
|
||||
link_obs_shape = [5] * self.num_links
|
||||
observation_shape = node_obs_shape + link_obs_shape
|
||||
|
||||
# 2. Create observation space & zeroed out sample from space.
|
||||
observation_space = spaces.MultiDiscrete(observation_shape)
|
||||
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
|
||||
f", got {self.observation_type} instead"
|
||||
)
|
||||
|
||||
return observation_space, initial_observation
|
||||
|
||||
def update_environent_obs(self):
|
||||
"""Updates the observation space based on the node and link status."""
|
||||
item_index = 0
|
||||
if self.observation_type == ObservationType.BOX:
|
||||
item_index = 0
|
||||
|
||||
# Do nodes first
|
||||
for node_key, node in self.nodes.items():
|
||||
self.env_obs[item_index][0] = int(node.node_id)
|
||||
self.env_obs[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.env_obs[item_index][2] = node.software_state.value
|
||||
self.env_obs[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
# Do nodes first
|
||||
for node_key, node in self.nodes.items():
|
||||
self.env_obs[item_index][0] = int(node.node_id)
|
||||
self.env_obs[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.env_obs[item_index][2] = node.software_state.value
|
||||
self.env_obs[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.services_list:
|
||||
if node.has_service(service):
|
||||
self.env_obs[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
else:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
else:
|
||||
# Not a service node
|
||||
for service in self.services_list:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
item_index += 1
|
||||
|
||||
# Now do links
|
||||
for link_key, link in self.links.items():
|
||||
self.env_obs[item_index][0] = int(link.get_id())
|
||||
self.env_obs[item_index][1] = 0
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.services_list:
|
||||
if node.has_service(service):
|
||||
self.env_obs[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
else:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
else:
|
||||
# Not a service node
|
||||
for service in self.services_list:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
item_index += 1
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
# Now do links
|
||||
for link_key, link in self.links.items():
|
||||
self.env_obs[item_index][0] = int(link.get_id())
|
||||
self.env_obs[item_index][1] = 0
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
||||
obs = []
|
||||
# 1. Set nodes
|
||||
# Each node has the following variables in the observation space:
|
||||
# - Hardware state
|
||||
# - Software state
|
||||
# - File System state
|
||||
# - Service 1 state
|
||||
# - Service 2 state
|
||||
# - ...
|
||||
# - Service N state
|
||||
for node_key, node in self.nodes.items():
|
||||
hardware_state = node.hardware_state.value
|
||||
software_state = 0
|
||||
file_system_state = 0
|
||||
services_states = [0] * self.num_services
|
||||
|
||||
if isinstance(
|
||||
node, ActiveNode
|
||||
): # ServiceNode is a subclass of ActiveNode so no need to check that also
|
||||
software_state = node.software_state.value
|
||||
file_system_state = node.file_system_state_observed.value
|
||||
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.services_list):
|
||||
if node.has_service(service):
|
||||
services_states[i] = node.get_service_state(service).value
|
||||
|
||||
obs.extend(
|
||||
[
|
||||
hardware_state,
|
||||
software_state,
|
||||
file_system_state,
|
||||
*services_states,
|
||||
]
|
||||
)
|
||||
|
||||
# 2. Set links
|
||||
# Each link has just one variable in the observation space, it represents the traffic amount
|
||||
# In order for the space to be fully MultiDiscrete, the amount of
|
||||
# traffic on each link is quantised into a few levels:
|
||||
# 0: no traffic (0% of bandwidth)
|
||||
# 1: low traffic (0-33% of bandwidth)
|
||||
# 2: medium traffic (33-66% of bandwidth)
|
||||
# 3: high traffic (66-100% of bandwidth)
|
||||
# 4: max traffic/overloaded (100% of bandwidth)
|
||||
|
||||
for link_key, link in self.links.items():
|
||||
bandwidth = link.bandwidth
|
||||
load = link.get_current_load()
|
||||
|
||||
if load <= 0:
|
||||
traffic_level = 0
|
||||
elif load >= bandwidth:
|
||||
traffic_level = 4
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (1 / 3) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
self.env_obs = np.asarray(obs)
|
||||
|
||||
def load_config(self):
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
@@ -748,6 +857,9 @@ class Primaite(Env):
|
||||
elif item["itemType"] == "ACTIONS":
|
||||
# Get the action information
|
||||
self.get_action_info(item)
|
||||
elif item["itemType"] == "OBSERVATIONS":
|
||||
# Get the observation information
|
||||
self.get_observation_info(item)
|
||||
elif item["itemType"] == "STEPS":
|
||||
# Get the steps information
|
||||
self.get_steps_info(item)
|
||||
@@ -1080,6 +1192,16 @@ class Primaite(Env):
|
||||
"""
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def get_observation_info(self, observation_info):
|
||||
"""Extracts observation_info.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
observation_info : str
|
||||
Config item that defines which type of observation space to use
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
def get_steps_info(self, steps_info):
|
||||
"""
|
||||
Extracts steps_info.
|
||||
|
||||
Reference in New Issue
Block a user