901 - removed print statements and merged with dev
This commit is contained in:
@@ -18,7 +18,6 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
|
||||
def _get_primaite_config():
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
print("config path", config_path)
|
||||
if not config_path.exists():
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
with open(config_path, "r") as file:
|
||||
|
||||
@@ -51,14 +51,15 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
|
||||
@@ -12,10 +12,10 @@ agent_identifier: PPO
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 1
|
||||
# Number of time_steps per episode
|
||||
num_steps: 5
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 1
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 5
|
||||
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
||||
apply_implicit_rule: True
|
||||
|
||||
@@ -62,7 +62,7 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
shutil.rmtree(self.session_path)
|
||||
shutil.rmtree(self.session_path.parent)
|
||||
# shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
|
||||
|
||||
|
||||
@@ -5,20 +5,6 @@ import pytest
|
||||
|
||||
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_primaite_env_from_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env(request):
|
||||
"""Build Primaite environment for integration tests of observation space."""
|
||||
marker = request.node.get_closest_marker("env_config_paths")
|
||||
training_config_path = marker.args[0]["training_config_path"]
|
||||
lay_down_config_path = marker.args[0]["lay_down_config_path"]
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path,
|
||||
)
|
||||
yield env
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -314,5 +300,5 @@ class TestAccessControlList:
|
||||
THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION
|
||||
"""
|
||||
# np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
# assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
assert obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
# assert obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
Reference in New Issue
Block a user