From 8008fab523df9a9309733a955c760982619bf607 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Mon, 17 Jul 2023 13:44:16 +0100 Subject: [PATCH] #901 - Removed flatten from training configs - Added flatten operation in observations.py when there are multiple obs components - Updated config.rst docs --- docs/source/config.rst | 3 ++- src/primaite/environment/observations.py | 6 ++---- tests/config/ppo_not_seeded_training_config.yaml | 2 +- tests/config/ppo_seeded_training_config.yaml | 1 - 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/source/config.rst b/docs/source/config.rst index 8367faf0..16740f1b 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -62,11 +62,11 @@ The environment config file consists of the following attributes: .. code-block:: yaml observation_space: - flatten: true components: - name: NODE_LINK_TABLE - name: NODE_STATUSES - name: LINK_TRAFFIC_LEVELS + - name: ACCESS_CONTROL_LIST options: combine_service_traffic : False quantisation_levels: 99 @@ -76,6 +76,7 @@ The environment config file consists of the following attributes: * :py:mod:`NODE_LINK_TABLE` this does not accept any additional options * :py:mod:`NODE_STATUSES`, this does not accept any additional options + * :py:mod:`ACCESS_CONTROL_LIST`, this does not accept additional options * :py:mod:`LINK_TRAFFIC_LEVELS`, this accepts the following options: * ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index bb5ec62c..70f3cdde 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -606,8 +606,6 @@ class ObservationsHandler: # used for transactions and when flatten=true self._flat_observation: np.ndarray - self.flatten: bool = False - def update_obs(self): """Fetch fresh information about the environment.""" current_obs = [] @@ -661,7 +659,7 @@ class ObservationsHandler: @property def space(self): """Observation space, return the flattened version if flatten is True.""" - if self.flatten: + if len(self.registered_obs_components) > 1: return self._flat_space else: return self._space @@ -669,7 +667,7 @@ class ObservationsHandler: @property def current_observation(self): """Current observation, return the flattened version if flatten is True.""" - if self.flatten: + if len(self.registered_obs_components) > 1: return self._flat_observation else: return self._observation diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml index 3d638ac6..ef23d432 100644 --- a/tests/config/ppo_not_seeded_training_config.yaml +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -54,11 +54,11 @@ hard_coded_agent_view: FULL action_type: NODE # observation space observation_space: - # flatten: true components: - name: NODE_LINK_TABLE # - name: NODE_STATUSES # - name: LINK_TRAFFIC_LEVELS + # - name: ACCESS_CONTROL_LIST # Number of episodes to run per session num_train_episodes: 10 diff --git a/tests/config/ppo_seeded_training_config.yaml b/tests/config/ppo_seeded_training_config.yaml index 86abcae7..2c7c117c 100644 --- a/tests/config/ppo_seeded_training_config.yaml +++ b/tests/config/ppo_seeded_training_config.yaml @@ -54,7 +54,6 @@ hard_coded_agent_view: FULL action_type: NODE # observation space observation_space: - # flatten: true components: - name: NODE_LINK_TABLE # - name: NODE_STATUSES