From 9666b92caa45b96cc007ca5ecf9456e079e12bdc Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 28 Jun 2023 11:07:45 +0100 Subject: [PATCH 1/8] Attempt to add flat spaces --- scratch.py | 6 +++++ .../training/training_config_main.yaml | 9 +++++-- src/primaite/environment/observations.py | 24 +++++++++++++++---- 3 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 scratch.py diff --git a/scratch.py b/scratch.py new file mode 100644 index 00000000..6bab60c1 --- /dev/null +++ b/scratch.py @@ -0,0 +1,6 @@ +from primaite.main import run + +run( + "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml", + "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml", +) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..a679400c 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -11,12 +11,17 @@ agent_identifier: STABLE_BASELINES3_A2C # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE # Number of episodes to run per session -num_episodes: 10 +num_episodes: 1000 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) -time_delay: 10 +time_delay: 0 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING # Determine whether to load an agent from file diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..e6eb533c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -311,8 +311,13 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] + + # need to keep track of the flattened and unflattened version of the space (if there is one) self.space: spaces.Space + self.unflattened_space: spaces.Space + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] + self.flatten: bool = False def update_obs(self): """Fetch fresh information about the environment.""" @@ -324,9 +329,14 @@ class ObservationsHandler: # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: self.current_observation = current_obs[0] + # If there are many compoenents, the space may need to be flattened else: - self.current_observation = tuple(current_obs) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + if self.flatten: + self.current_observation = spaces.flatten( + self.unflattened_space, tuple(current_obs) + ) + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -357,8 +367,11 @@ class ObservationsHandler: if len(component_spaces) == 1: self.space = component_spaces[0] else: - self.space = spaces.Tuple(component_spaces) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + self.unflattened_space = spaces.Tuple(component_spaces) + if self.flatten: + self.space = spaces.flatten_space(spaces.Tuple(component_spaces)) + else: + self.space = self.unflattened_space @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -388,6 +401,9 @@ class ObservationsHandler: # Instantiate the handler handler = cls() + if obs_space_config.get("flatten"): + handler.flatten = True + for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] From c77fde3dd33489647feede79d86f0483cc1259c1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Jun 2023 15:26:07 +0100 Subject: [PATCH 2/8] Fix observation representation in transactions --- src/primaite/environment/observations.py | 149 +++++++++++++++--- src/primaite/environment/primaite_env.py | 5 +- src/primaite/main.py | 1 + .../transactions/transactions_to_file.py | 54 ++----- 4 files changed, 150 insertions(+), 59 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index e6eb533c..023c5f30 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? + self.structure: list[str] return NotImplemented @abstractmethod @@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC): """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented + @abstractmethod + def generate_structure(self) -> List[str]: + """Return a list of labels for the components of the flattened observation space.""" + return NotImplemented + class NodeLinkTable(AbstractObservationComponent): """Table with nodes and links as rows and hardware/software status as cols. @@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent): # 3. Initialise Observation with zeroes self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent): protocol_index += 1 item_index += 1 + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + nodes = self.env.nodes.values() + links = self.env.links.values() + + structure = [] + + for i, node in enumerate(nodes): + node_id = node.node_id + node_labels = [ + f"node_{node_id}_id", + f"node_{node_id}_hardware_status", + f"node_{node_id}_os_status", + f"node_{node_id}_fs_status", + ] + for j, serv in enumerate(self.env.services_list): + node_labels.append(f"node_{node_id}_service_{serv}_status") + + structure.extend(node_labels) + + for i, link in enumerate(links): + link_id = link.id + link_labels = [ + f"link_{link_id}_id", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + ] + for j, serv in enumerate(self.env.services_list): + link_labels.append(f"node_{node_id}_service_{serv}_load") + + structure.extend(link_labels) + return structure + class NodeStatuses(AbstractObservationComponent): """Flat list of nodes' hardware, OS, file system, and service states. @@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() def update(self): """Update the observation based on current environment state. @@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent): ) self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + services = self.env.services_list + + structure = [] + for _, node in self.env.nodes.items(): + node_id = node.node_id + structure.append(f"node_{node_id}_hardware_state_NONE") + for state in HardwareState: + structure.append(f"node_{node_id}_hardware_state_{state.name}") + structure.append(f"node_{node_id}_software_state_NONE") + for state in SoftwareState: + structure.append(f"node_{node_id}_software_state_{state.name}") + structure.append(f"node_{node_id}_file_system_state_NONE") + for state in FileSystemState: + structure.append(f"node_{node_id}_file_system_state_{state.name}") + for service in services: + structure.append(f"node_{node_id}_service_{service}_state_NONE") + for state in SoftwareState: + structure.append( + f"node_{node_id}_service_{service}_state_{state.name}" + ) + return structure + class LinkTrafficLevels(AbstractObservationComponent): """Flat list of traffic levels encoded into banded categories. @@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + structure = [] + for _, link in self.env.links.items(): + link_id = link.id + if self._combine_service_traffic: + protocols = ["overall"] + else: + protocols = [protocol.name for protocol in link.protocol_list] + + for p in protocols: + for i in range(self._quantisation_levels): + structure.append(f"link_{link_id}_{p}_traffic_level_{i}") + return structure + class ObservationsHandler: """Component-based observation space handler. @@ -312,11 +396,15 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] - # need to keep track of the flattened and unflattened version of the space (if there is one) - self.space: spaces.Space - self.unflattened_space: spaces.Space + # internal the observation space (unflattened version of space if flatten=True) + self._space: spaces.Space + # flattened version of the observation space + self._flat_space: spaces.Space + + self._observation: Union[Tuple[np.ndarray], np.ndarray] + # used for transactions and when flatten=true + self._flat_observation: np.ndarray - self.current_observation: Union[Tuple[np.ndarray], np.ndarray] self.flatten: bool = False def update_obs(self): @@ -326,17 +414,11 @@ class ObservationsHandler: obs.update() current_obs.append(obs.current_observation) - # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: - self.current_observation = current_obs[0] - # If there are many compoenents, the space may need to be flattened + self._observation = current_obs[0] else: - if self.flatten: - self.current_observation = spaces.flatten( - self.unflattened_space, tuple(current_obs) - ) - else: - self.current_observation = tuple(current_obs) + self._observation = tuple(current_obs) + self._flat_observation = spaces.flatten(self._space, self._observation) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -363,15 +445,28 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - # If there is only one component, don't use a tuple space, just pass through that component's space. + # if there are multiple components, build a composite tuple space if len(component_spaces) == 1: - self.space = component_spaces[0] + self._space = component_spaces[0] else: - self.unflattened_space = spaces.Tuple(component_spaces) - if self.flatten: - self.space = spaces.flatten_space(spaces.Tuple(component_spaces)) - else: - self.space = self.unflattened_space + self._space = spaces.Tuple(component_spaces) + self._flat_space = spaces.flatten_space(self._space) + + @property + def space(self): + """Observation space, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_space + else: + return self._space + + @property + def current_observation(self): + """Current observation, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_observation + else: + return self._observation @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -417,3 +512,17 @@ class ObservationsHandler: handler.update_obs() return handler + + def describe_structure(self): + """Create a list of names for the features of the obs space. + + The order of labels follows the flattened version of the space. + """ + # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have + # to fake it. each component has to just hard-code the expected label order after flattening... + + labels = [] + for obs_comp in self.registered_obs_components: + labels.extend(obs_comp.structure) + + return labels diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index be4cc434..e56abf9d 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -318,7 +318,8 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) + transaction.set_obs_space_pre(self.obs_handler._flat_observation) + # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -400,7 +401,7 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() # Load the new observation space into the transaction - transaction.set_obs_space_post(copy.deepcopy(self.env_obs)) + transaction.set_obs_space_post(self.obs_handler._flat_observation) # 8. Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) diff --git a/src/primaite/main.py b/src/primaite/main.py index f5e94509..4d83f604 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -325,6 +325,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, transaction_list=transaction_list, session_path=session_dir, timestamp_str=timestamp_str, + obs_space_description=env.obs_handler.describe_structure(), ) print("Updating Session Metadata file...") diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 11e68af8..b2a4d40d 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space): return [str(_action_space)] -def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): - """ - Turns observation space into a string array so it can be saved to csv. - - Args: - _obs_space: The observation space - _obs_assets: The number of assets (i.e. nodes or links) in the observation space - _obs_features: The number of features associated with the asset - """ - return_array = [] - for x in range(_obs_assets): - for y in range(_obs_features): - return_array.append(str(_obs_space[x][y])) - - return return_array - - -def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): +def write_transaction_to_file( + transaction_list, + session_path: Path, + timestamp_str: str, + obs_space_description: list, +): """ Writes transaction logs to file to support training evaluation. @@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st # This will be tied into the PrimAITE Use Case so that they make sense template_transation = transaction_list[0] action_length = template_transation.action_space.size - obs_shape = template_transation.obs_space_post.shape - obs_assets = template_transation.obs_space_post.shape[0] - if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - obs_features = 1 - else: - obs_features = template_transation.obs_space_post.shape[1] + # obs_shape = template_transation.obs_space_post.shape + # obs_assets = template_transation.obs_space_post.shape[0] + # if len(obs_shape) == 1: + # bit of a workaround but I think the way transactions are written will change soon + # obs_features = 1 + # else: + # obs_features = template_transation.obs_space_post.shape[1] # Create the action space headers array action_header = [] @@ -70,12 +58,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) + obs_header_initial = [f"pre_{o}" for o in obs_space_description] + obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] @@ -98,12 +82,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + turn_obs_space_to_array( - transaction.obs_space_pre, obs_assets, obs_features - ) - + turn_obs_space_to_array( - transaction.obs_space_post, obs_assets, obs_features - ) + + transaction.obs_space_pre.tolist() + + transaction.obs_space_post.tolist() ) csv_writer.writerow(csv_data) From c3c45125448c905ddc7fffa4921b47187f3bea76 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 09:54:34 +0100 Subject: [PATCH 3/8] Remove temporary file --- scratch.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 scratch.py diff --git a/scratch.py b/scratch.py deleted file mode 100644 index 6bab60c1..00000000 --- a/scratch.py +++ /dev/null @@ -1,6 +0,0 @@ -from primaite.main import run - -run( - "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml", - "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml", -) From 2a8d28cba68190f3ec528812adaaa09318395f69 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 10:41:56 +0100 Subject: [PATCH 4/8] Remove redundant cols from transactions --- src/primaite/environment/observations.py | 2 +- src/primaite/environment/primaite_env.py | 4 +--- src/primaite/transactions/transaction.py | 13 ++----------- src/primaite/transactions/transactions_to_file.py | 9 ++++----- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 023c5f30..fcd52559 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -168,7 +168,7 @@ class NodeLinkTable(AbstractObservationComponent): f"link_{link_id}_n/a", ] for j, serv in enumerate(self.env.services_list): - link_labels.append(f"node_{node_id}_service_{serv}_load") + link_labels.append(f"link_{link_id}_service_{serv}_load") structure.extend(link_labels) return structure diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e56abf9d..2418cac0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -318,7 +318,7 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(self.obs_handler._flat_observation) + transaction.set_obs_space(self.obs_handler._flat_observation) # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -400,8 +400,6 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() - # Load the new observation space into the transaction - transaction.set_obs_space_post(self.obs_handler._flat_observation) # 8. Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a4ce48e3..39236217 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -20,23 +20,14 @@ class Transaction(object): self.episode_number = _episode_number self.step_number = _step_number - def set_obs_space_pre(self, _obs_space_pre): + def set_obs_space(self, _obs_space): """ Sets the observation space (pre). Args: _obs_space_pre: The observation space before any actions are taken """ - self.obs_space_pre = _obs_space_pre - - def set_obs_space_post(self, _obs_space_post): - """ - Sets the observation space (post). - - Args: - _obs_space_post: The observation space after any actions are taken - """ - self.obs_space_post = _obs_space_post + self.obs_space = _obs_space def set_reward(self, _reward): """ diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index b2a4d40d..4e364f0b 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -58,12 +58,12 @@ def write_transaction_to_file( action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [f"pre_{o}" for o in obs_space_description] - obs_header_new = [f"post_{o}" for o in obs_space_description] + # obs_header_initial = [f"pre_{o}" for o in obs_space_description] + # obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new + header = header + action_header + obs_space_description try: filename = session_path / f"all_transactions_{timestamp_str}.csv" @@ -82,8 +82,7 @@ def write_transaction_to_file( csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + transaction.obs_space_pre.tolist() - + transaction.obs_space_post.tolist() + + transaction.obs_space.tolist() ) csv_writer.writerow(csv_data) From 32d5889b11e405a0bd5e4ed72eb9727f37aa4652 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 10:44:04 +0100 Subject: [PATCH 5/8] Update docs --- docs/source/primaite_session.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 4f639f11..a59b2361 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session: * Timestamp * Episode number * Step number - * Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y - * Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y + * Initial observation space (what the blue agent observed when it decided its action) * Reward value - * Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X + * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X **Diagrams** From 975ebd6de2d43aa0ad65d1d69cf45c37c81aa609 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 13:16:30 +0100 Subject: [PATCH 6/8] revert unnecessary changes. --- .../_package_data/training/training_config_main.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a679400c..ac63c667 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -5,7 +5,7 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +agent_identifier: STABLE_BASELINES3_PPO # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -16,12 +16,14 @@ observation_space: # flatten: true components: - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session -num_episodes: 1000 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) -time_delay: 0 +time_delay: 10 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING # Determine whether to load an agent from file From 605ff98a24eaca0e34b3d4c24e0dc8b4fc42761b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 15:43:15 +0100 Subject: [PATCH 7/8] Fix flattening when there are no components. --- src/primaite/environment/observations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index fcd52559..b19bd29f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -450,7 +450,10 @@ class ObservationsHandler: self._space = component_spaces[0] else: self._space = spaces.Tuple(component_spaces) - self._flat_space = spaces.flatten_space(self._space) + if len(component_spaces) > 0: + self._flat_space = spaces.flatten_space(self._space) + else: + self._flat_space = spaces.Box(0, 1, (0,)) @property def space(self): From ee94993344d8975336859845ed06364ff80c9a4e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 08:00:51 +0000 Subject: [PATCH 8/8] Apply suggestions from code review --- src/primaite/environment/observations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index b19bd29f..81ddaaf5 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,7 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? - self.structure: list[str] + self.structure: List[str] return NotImplemented @abstractmethod