From c0b214612a6cf4b7290325a55292402544634aa7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 18:01:47 +0100 Subject: [PATCH] Let single-component spaces not use Tuple Spaces --- src/primaite/environment/observations.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a1b0d9ac..5bad056c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Tuple, Union import numpy as np from gym import spaces @@ -278,7 +278,7 @@ class ObservationsHandler: """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space - self.current_observation: Tuple[np.ndarray] + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] # i can access the registry items like this: # self.registry.LINK_TRAFFIC_LEVELS @@ -288,7 +288,12 @@ class ObservationsHandler: for obs in self.registered_obs_components: obs.update() current_obs.append(obs.current_observation) - self.current_observation = tuple(current_obs) + + # If there is only one component, don't use a tuple, just pass through that component's obs. + if len(current_obs) == 1: + self.current_observation = current_obs[0] + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """TODO: complete description.""" @@ -305,7 +310,12 @@ class ObservationsHandler: component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - self.space = spaces.Tuple(component_spaces) + + # If there is only one component, don't use a tuple space, just pass through that component's space. + if len(component_spaces) == 1: + self.space = component_spaces[0] + else: + self.space = spaces.Tuple(component_spaces) @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict):