Let single-component spaces not use Tuple Spaces
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
@@ -278,7 +278,7 @@ class ObservationsHandler:
|
||||
"""Initialise the handler without any components yet. They"""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
self.space: spaces.Space
|
||||
self.current_observation: Tuple[np.ndarray]
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# i can access the registry items like this:
|
||||
# self.registry.LINK_TRAFFIC_LEVELS
|
||||
|
||||
@@ -288,7 +288,12 @@ class ObservationsHandler:
|
||||
for obs in self.registered_obs_components:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
self.current_observation = tuple(current_obs)
|
||||
|
||||
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||
if len(current_obs) == 1:
|
||||
self.current_observation = current_obs[0]
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""TODO: complete description."""
|
||||
@@ -305,7 +310,12 @@ class ObservationsHandler:
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
|
||||
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
||||
if len(component_spaces) == 1:
|
||||
self.space = component_spaces[0]
|
||||
else:
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
|
||||
Reference in New Issue
Block a user