Add tests for observations

This commit is contained in:
Marek Wolan
2023-06-02 12:59:01 +01:00
parent b6ce1cbae9
commit f37b943f7e
8 changed files with 476 additions and 32 deletions

View File

@@ -165,14 +165,15 @@ class NodeStatuses(AbstractObservationComponent):
super().__init__(env)
# 1. Define the shape of your observation space component
shape = [
node_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
shape = shape + services_shape
node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
@@ -199,7 +200,9 @@ class NodeStatuses(AbstractObservationComponent):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend([hardware_state, software_state, file_system_state, *service_states])
obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
)
self.current_observation[:] = obs
@@ -303,8 +306,6 @@ class ObservationsHandler:
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
# i can access the registry items like this:
# self.registry.LINK_TRAFFIC_LEVELS
def update_obs(self):
"""Fetch fresh information about the environment."""
@@ -318,6 +319,7 @@ class ObservationsHandler:
self.current_observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
@@ -349,6 +351,7 @@ class ObservationsHandler:
self.space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):