Better Obs default handling

This commit is contained in:
Marek Wolan
2023-06-01 17:50:18 +01:00
parent e43649a838
commit 0a804e714d

View File

@@ -47,15 +47,12 @@ _LOGGER = logging.getLogger(__name__)
class Primaite(Env):
"""PRIMmary AI Training Evironment (Primaite) class."""
# Observation / Action Space contants
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
# Action Space contants
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
ACTION_SPACE_NODE_ACTION_VALUES = 4
ACTION_SPACE_ACL_ACTION_VALUES = 3
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
def __init__(self, _config_values, _transaction_list):
"""
Init.
@@ -149,8 +146,11 @@ class Primaite(Env):
# The action type
self.action_type = 0
# TODO: proper description here
self.obs_config: dict
# stores the observation config from the yaml, default is NODE_LINK_TABLE
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
# Observation Handler manages the user-configurable observation space.
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
@@ -192,10 +192,6 @@ class Primaite(Env):
_LOGGER.error("Exception occured", exc_info=True)
print("Could not save network diagram")
# # If it doesn't exist after parsing config, create default obs space.
# if getattr(self, "obs_handler", None) is None:
# self.configure_obs_space()
# Initiate observation space
self.observation_space, self.env_obs = self.init_observations()
@@ -648,13 +644,6 @@ class Primaite(Env):
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""TODO: write docstring."""
if getattr(self, "obs_config", None) is None:
self.obs_config = {
"components": [
{"name": "NODE_LINK_TABLE"},
]
}
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
return self.obs_handler.space, self.obs_handler.current_observation