Let single-component spaces not use Tuple Spaces

This commit is contained in:
Marek Wolan
2023-06-01 18:01:47 +01:00
parent 3e208bad9b
commit c0b214612a

View File

@@ -1,7 +1,7 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Tuple, Union
import numpy as np
from gym import spaces
@@ -278,7 +278,7 @@ class ObservationsHandler:
"""Initialise the handler without any components yet. They"""
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Tuple[np.ndarray]
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
# i can access the registry items like this:
# self.registry.LINK_TRAFFIC_LEVELS
@@ -288,7 +288,12 @@ class ObservationsHandler:
for obs in self.registered_obs_components:
obs.update()
current_obs.append(obs.current_observation)
self.current_observation = tuple(current_obs)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
def register(self, obs_component: AbstractObservationComponent):
"""TODO: complete description."""
@@ -305,7 +310,12 @@ class ObservationsHandler:
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
self.space = spaces.Tuple(component_spaces)
# If there is only one component, don't use a tuple space, just pass through that component's space.
if len(component_spaces) == 1:
self.space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):