#901 - Removed flatten from training configs

- Added flatten operation in observations.py when there are multiple obs components
- Updated config.rst docs
This commit is contained in:
SunilSamra
2023-07-17 13:44:16 +01:00
parent da20c0e9e6
commit 78d7f39342
4 changed files with 5 additions and 7 deletions

View File

@@ -62,11 +62,11 @@ The environment config file consists of the following attributes:
.. code-block:: yaml
observation_space:
flatten: true
components:
- name: NODE_LINK_TABLE
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
options:
combine_service_traffic : False
quantisation_levels: 99
@@ -76,6 +76,7 @@ The environment config file consists of the following attributes:
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)

View File

@@ -606,8 +606,6 @@ class ObservationsHandler:
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
self.flatten: bool = False
def update_obs(self):
"""Fetch fresh information about the environment."""
current_obs = []
@@ -661,7 +659,7 @@ class ObservationsHandler:
@property
def space(self):
"""Observation space, return the flattened version if flatten is True."""
if self.flatten:
if len(self.registered_obs_components) > 1:
return self._flat_space
else:
return self._space
@@ -669,7 +667,7 @@ class ObservationsHandler:
@property
def current_observation(self):
"""Current observation, return the flattened version if flatten is True."""
if self.flatten:
if len(self.registered_obs_components) > 1:
return self._flat_observation
else:
return self._observation

View File

@@ -54,11 +54,11 @@ hard_coded_agent_view: FULL
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# - name: ACCESS_CONTROL_LIST
# Number of episodes to run per session
num_train_episodes: 10

View File

@@ -54,7 +54,6 @@ hard_coded_agent_view: FULL
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES