diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a1b0d9ac..5bad056c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Tuple, Union import numpy as np from gym import spaces @@ -278,7 +278,7 @@ class ObservationsHandler: """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space - self.current_observation: Tuple[np.ndarray] + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] # i can access the registry items like this: # self.registry.LINK_TRAFFIC_LEVELS @@ -288,7 +288,12 @@ class ObservationsHandler: for obs in self.registered_obs_components: obs.update() current_obs.append(obs.current_observation) - self.current_observation = tuple(current_obs) + + # If there is only one component, don't use a tuple, just pass through that component's obs. + if len(current_obs) == 1: + self.current_observation = current_obs[0] + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """TODO: complete description.""" @@ -305,7 +310,12 @@ class ObservationsHandler: component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - self.space = spaces.Tuple(component_spaces) + + # If there is only one component, don't use a tuple space, just pass through that component's space. + if len(component_spaces) == 1: + self.space = component_spaces[0] + else: + self.space = spaces.Tuple(component_spaces) @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict):