From 3e208bad9bf553acc2eb4157d4bd07da906d5227 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 17:50:18 +0100 Subject: [PATCH] Better Obs default handling --- src/primaite/environment/primaite_env.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0107920f..81557075 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -47,15 +47,12 @@ _LOGGER = logging.getLogger(__name__) class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" - # Observation / Action Space contants - OBSERVATION_SPACE_FIXED_PARAMETERS = 4 + # Action Space contants ACTION_SPACE_NODE_PROPERTY_VALUES = 5 ACTION_SPACE_NODE_ACTION_VALUES = 4 ACTION_SPACE_ACL_ACTION_VALUES = 3 ACTION_SPACE_ACL_PERMISSION_VALUES = 2 - OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space - def __init__(self, _config_values, _transaction_list): """ Init. @@ -149,8 +146,11 @@ class Primaite(Env): # The action type self.action_type = 0 - # TODO: proper description here - self.obs_config: dict + # stores the observation config from the yaml, default is NODE_LINK_TABLE + self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} + + # Observation Handler manages the user-configurable observation space. + # It will be initialised later. self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -192,10 +192,6 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") - # # If it doesn't exist after parsing config, create default obs space. - # if getattr(self, "obs_handler", None) is None: - # self.configure_obs_space() - # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -648,13 +644,6 @@ class Primaite(Env): def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: """TODO: write docstring.""" - if getattr(self, "obs_config", None) is None: - self.obs_config = { - "components": [ - {"name": "NODE_LINK_TABLE"}, - ] - } - self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation