diff --git a/scratch.py b/scratch.py
new file mode 100644
index 00000000..6bab60c1
--- /dev/null
+++ b/scratch.py
@@ -0,0 +1,6 @@
+from primaite.main import run
+
+run(
+    "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml",
+    "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml",
+)
diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml
index d01f51f3..a679400c 100644
--- a/src/primaite/config/_package_data/training/training_config_main.yaml
+++ b/src/primaite/config/_package_data/training/training_config_main.yaml
@@ -11,12 +11,17 @@ agent_identifier: STABLE_BASELINES3_A2C
 # "ACL"
 # "ANY" node and acl actions
 action_type: NODE
+# observation space
+observation_space:
+  # flatten: true
+  components:
+    - name: NODE_LINK_TABLE
 # Number of episodes to run per session
-num_episodes: 10
+num_episodes: 1000
 # Number of time_steps per episode
 num_steps: 256
 # Time delay between steps (for generic agents)
-time_delay: 10
+time_delay: 0
 # Type of session to be run (TRAINING or EVALUATION)
 session_type: TRAINING
 # Determine whether to load an agent from file
diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py
index 9e71ef1b..e6eb533c 100644
--- a/src/primaite/environment/observations.py
+++ b/src/primaite/environment/observations.py
@@ -311,8 +311,13 @@ class ObservationsHandler:
 
     def __init__(self):
         self.registered_obs_components: List[AbstractObservationComponent] = []
+
+        # need to keep track of the flattened and unflattened version of the space (if there is one)
         self.space: spaces.Space
+        self.unflattened_space: spaces.Space
+
         self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
+        self.flatten: bool = False
 
     def update_obs(self):
         """Fetch fresh information about the environment."""
@@ -324,9 +329,14 @@ class ObservationsHandler:
         # If there is only one component, don't use a tuple, just pass through that component's obs.
         if len(current_obs) == 1:
             self.current_observation = current_obs[0]
+        # If there are many components, the space may need to be flattened
         else:
-            self.current_observation = tuple(current_obs)
-            # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
+            if self.flatten:
+                self.current_observation = spaces.flatten(
+                    self.unflattened_space, tuple(current_obs)
+                )
+            else:
+                self.current_observation = tuple(current_obs)
 
     def register(self, obs_component: AbstractObservationComponent):
         """Add a component for this handler to track.
@@ -357,8 +367,11 @@ class ObservationsHandler:
         if len(component_spaces) == 1:
             self.space = component_spaces[0]
         else:
-            self.space = spaces.Tuple(component_spaces)
-            # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
+            self.unflattened_space = spaces.Tuple(component_spaces)
+            if self.flatten:
+                self.space = spaces.flatten_space(self.unflattened_space)
+            else:
+                self.space = self.unflattened_space
 
     @classmethod
     def from_config(cls, env: "Primaite", obs_space_config: dict):
@@ -388,6 +401,9 @@ class ObservationsHandler:
         # Instantiate the handler
         handler = cls()
 
+        if obs_space_config.get("flatten"):
+            handler.flatten = True
+
         for component_cfg in obs_space_config["components"]:
             # Figure out which class can instantiate the desired component
             comp_type = component_cfg["name"]