Attempt to add flat spaces

This commit is contained in:
Marek Wolan
2023-06-28 11:07:45 +01:00
parent d28db68c02
commit e086d419ad
3 changed files with 33 additions and 6 deletions

6
scratch.py Normal file
View File

@@ -0,0 +1,6 @@
from primaite.main import run
# Absolute paths (on the author's machine) to the packaged training config
# and the data-manipulation lay-down config.
training_config_path = (
    "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml"
)
lay_down_config_path = (
    "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml"
)

# Hand both config files to primaite.main.run, training config first.
run(training_config_path, lay_down_config_path)

View File

@@ -11,12 +11,17 @@ agent_identifier: STABLE_BASELINES3_A2C
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# Number of episodes to run per session
num_episodes: 10
num_episodes: 1000
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
time_delay: 0
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file

View File

@@ -311,8 +311,13 @@ class ObservationsHandler:
def __init__(self):
    """Create a handler with no registered components and flattening switched off."""
    # Observation components tracked by this handler; starts empty.
    self.registered_obs_components: List[AbstractObservationComponent] = []

    # need to keep track of the flattened and unflattened version of the space (if there is one)
    # NOTE: the two lines below are annotation-only declarations; no value is
    # assigned here, so the attributes do not exist until they are set elsewhere.
    self.space: spaces.Space
    self.unflattened_space: spaces.Space

    # Most recent observation: a single array, or a tuple of arrays.
    # Annotation only; no value is assigned here.
    self.current_observation: Union[Tuple[np.ndarray], np.ndarray]

    # Whether the observation space should be flattened; off by default.
    self.flatten: bool = False
def update_obs(self):
"""Fetch fresh information about the environment."""
@@ -324,9 +329,14 @@ class ObservationsHandler:
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
# If there are many components, the space may need to be flattened
else:
self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
if self.flatten:
self.current_observation = spaces.flatten(
self.unflattened_space, tuple(current_obs)
)
else:
self.current_observation = tuple(current_obs)
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
@@ -357,8 +367,11 @@ class ObservationsHandler:
if len(component_spaces) == 1:
self.space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
self.unflattened_space = spaces.Tuple(component_spaces)
if self.flatten:
self.space = spaces.flatten_space(spaces.Tuple(component_spaces))
else:
self.space = self.unflattened_space
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
@@ -388,6 +401,9 @@ class ObservationsHandler:
# Instantiate the handler
handler = cls()
if obs_space_config.get("flatten"):
handler.flatten = True
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]