Attempt to add flat spaces
This commit is contained in:
6
scratch.py
Normal file
6
scratch.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from primaite.main import run
|
||||
|
||||
run(
|
||||
"/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml",
|
||||
"/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml",
|
||||
)
|
||||
@@ -11,12 +11,17 @@ agent_identifier: STABLE_BASELINES3_A2C
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
num_episodes: 1000
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 10
|
||||
time_delay: 0
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
|
||||
@@ -311,8 +311,13 @@ class ObservationsHandler:
|
||||
|
||||
def __init__(self):
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
# need to keep track of the flattened and unflattened version of the space (if there is one)
|
||||
self.space: spaces.Space
|
||||
self.unflattened_space: spaces.Space
|
||||
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self):
|
||||
"""Fetch fresh information about the environment."""
|
||||
@@ -324,9 +329,14 @@ class ObservationsHandler:
|
||||
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||
if len(current_obs) == 1:
|
||||
self.current_observation = current_obs[0]
|
||||
# If there are many compoenents, the space may need to be flattened
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
if self.flatten:
|
||||
self.current_observation = spaces.flatten(
|
||||
self.unflattened_space, tuple(current_obs)
|
||||
)
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
@@ -357,8 +367,11 @@ class ObservationsHandler:
|
||||
if len(component_spaces) == 1:
|
||||
self.space = component_spaces[0]
|
||||
else:
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
self.unflattened_space = spaces.Tuple(component_spaces)
|
||||
if self.flatten:
|
||||
self.space = spaces.flatten_space(spaces.Tuple(component_spaces))
|
||||
else:
|
||||
self.space = self.unflattened_space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
@@ -388,6 +401,9 @@ class ObservationsHandler:
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
if obs_space_config.get("flatten"):
|
||||
handler.flatten = True
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
|
||||
Reference in New Issue
Block a user