Add tests for observations
This commit is contained in:
@@ -165,14 +165,15 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = [
|
||||
node_shape = [
|
||||
len(HardwareState) + 1,
|
||||
len(SoftwareState) + 1,
|
||||
len(FileSystemState) + 1,
|
||||
]
|
||||
services_shape = [len(SoftwareState) + 1] * self.env.num_services
|
||||
shape = shape + services_shape
|
||||
node_shape = node_shape + services_shape
|
||||
|
||||
shape = node_shape * self.env.num_nodes
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
@@ -199,7 +200,9 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
obs.extend([hardware_state, software_state, file_system_state, *service_states])
|
||||
obs.extend(
|
||||
[hardware_state, software_state, file_system_state, *service_states]
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
|
||||
@@ -303,8 +306,6 @@ class ObservationsHandler:
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
self.space: spaces.Space
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# i can access the registry items like this:
|
||||
# self.registry.LINK_TRAFFIC_LEVELS
|
||||
|
||||
def update_obs(self):
|
||||
"""Fetch fresh information about the environment."""
|
||||
@@ -318,6 +319,7 @@ class ObservationsHandler:
|
||||
self.current_observation = current_obs[0]
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
@@ -349,6 +351,7 @@ class ObservationsHandler:
|
||||
self.space = component_spaces[0]
|
||||
else:
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
|
||||
Reference in New Issue
Block a user