diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5d783af1..9b0bbeec 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -249,11 +249,13 @@ class Primaite(Env): self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) print(self.action_space, "ACL action space") - else: - _LOGGER.warning("Action space type ANY selected - Node + ACL") + elif self.action_type == ActionType.ANY: + _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) print(self.action_space, "ANY action space") + else: + _LOGGER.info("Invalid action type selected") # Set up a csv to store the results of the training try: now = datetime.now() # current date and time @@ -1271,6 +1273,5 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} - logging.warning("logging is working") # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 8c87d57b..203a6232 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -9,4 +9,18 @@ def test_single_action_space(): lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - print("Average Reward:", env.average_reward) + """ + env.action_space.n is the total number of actions in the Discrete action space + This is the number of actions the agent has to choose from. 
+ + The total number of actions that an agent can take when a NODE action type is selected is: 6 + The total number of actions that an agent can take when an ACL action type is selected is: 7 + + These action spaces are combined and the total number of actions is: 12 + This is due to both action spaces containing the action to "Do nothing", so it needs to be removed from one of the spaces, + to avoid duplicate actions. + + As a result, 12 is the total number of actions in the combined action space. + """ + # e + assert env.action_space.n == 12