901 - fixed test_observation_space.py, added test fixture for test_seeding_and_deterministic_session.py and increased default max number of acls
This commit is contained in:
@@ -91,6 +91,13 @@ session_type: TRAIN_EVAL
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Whether or not to apply an implicit ALLOW or DENY rule (TRUE or FALSE)
|
||||
apply_implicit_rule: False
|
||||
# Implicit ACL firewall rule at the end of the list to be the default action, or no rule can be selected (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
|
||||
@@ -106,7 +106,7 @@ class TrainingConfig:
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of the ACL list."
|
||||
|
||||
max_number_acl_rules: int = 10
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for the number of ACL rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
|
||||
@@ -448,8 +448,8 @@ class AccessControlList(AbstractObservationComponent):
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 1,
|
||||
len(env.ports_list) + 1,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules + 1,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
@@ -523,6 +523,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
# Either do the multiply on the obs space
|
||||
# Change the obs to
|
||||
print("current obs", port_int)
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
|
||||
@@ -62,7 +62,6 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
shutil.rmtree(self.session_path)
|
||||
# shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
|
||||
|
||||
@@ -120,6 +119,60 @@ def temp_primaite_session(request):
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_primaite_session_2(request):
|
||||
"""
|
||||
Provides a temporary PrimaiteSession instance.
|
||||
|
||||
It's temporary as it uses a temporary directory as the session path.
|
||||
|
||||
To use this fixture you need to:
|
||||
|
||||
- parametrize your test function with:
|
||||
|
||||
- "temp_primaite_session_2"
|
||||
- [[path to training config, path to lay down config]]
|
||||
- Include the temp_primaite_session_2 fixture as a param in your test
|
||||
function.
|
||||
- use the temp_primaite_session_2 as a context manager assigning it the
|
||||
name 'session'.
|
||||
|
||||
.. code:: python
|
||||
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session_2",
|
||||
[
|
||||
[main_training_config_path(), dos_very_basic_config_path()]
|
||||
],
|
||||
indirect=True
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session_2):
|
||||
with temp_primaite_session_2 as session:
|
||||
# Learning outputs are saved in session.learning_path
|
||||
session.learn()
|
||||
|
||||
# Evaluation outputs are saved in session.evaluation_path
|
||||
session.evaluate()
|
||||
|
||||
# To ensure that all files are written, you must call .close()
|
||||
session.close()
|
||||
|
||||
# If you need to inspect any session outputs, it must be done
|
||||
# inside the context manager
|
||||
|
||||
# Now that we've exited the context manager, the
|
||||
# session.session_path directory and its contents are deleted
|
||||
"""
|
||||
training_config_path = request.param[0]
|
||||
lay_down_config_path = request.param[1]
|
||||
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_session_path() -> Path:
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,41 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
def run_generic_set_actions(env: Primaite):
|
||||
"""Run against a generic agent with specified blue agent actions."""
|
||||
# Reset the environment at the start of the episode
|
||||
# env.reset()
|
||||
training_config = env.training_config
|
||||
for episode in range(0, training_config.num_train_episodes):
|
||||
for step in range(0, training_config.num_train_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
print("Episode:", episode, "\nStep:", step)
|
||||
if step == 2:
|
||||
# [1, 1, 2, 1, 1, 1, 1(position)]
|
||||
# NEED [1, 1, 1, 2, 1, 1, 1]
|
||||
# Creates an ACL rule
|
||||
# Allows traffic from server_1 to node_1 on port FTP
|
||||
action = 43
|
||||
elif step == 4:
|
||||
action = 96
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
return env
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
@@ -289,16 +321,23 @@ class TestAccessControlList:
|
||||
# Previously used env from the test fixture, but that raised AttributeError: function object has no 'training_config'
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
session.learn()
|
||||
env = run_generic_set_actions(env)
|
||||
obs = env.env_obs
|
||||
"""
|
||||
Observation space at the end of the episode.
|
||||
At the start of the episode, there is a single implicit Deny rule = 1,1,1,1,1,0
|
||||
(0 represents its initial position at top of ACL list)
|
||||
(1, 1, 1, 2, 1, 2, 0) - ACTION
|
||||
On Step 5, there is a rule added at POSITION 2: 2,2,3,2,3,0
|
||||
(1, 3, 1, 2, 2, 1) - SECOND ACTION
|
||||
On Step 7, there is a second rule added at POSITION 1: 2,4,2,3,3,1
|
||||
THINK THE RULES SHOULD BE THE OTHER WAY AROUND IN THE CURRENT OBSERVATION
|
||||
"""
|
||||
print("what i am testing", obs)
|
||||
# acl rule 1
|
||||
# source is 1 should be 4
|
||||
# dest is 3 should be 2
|
||||
# [2 2 3 2 3 0 2 1?4 3?2 3 3 1 1 1 1 1 1 2]
|
||||
# np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
assert np.array_equal(obs, [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2])
|
||||
# assert obs == [2, 2, 3, 2, 3, 0, 2, 4, 2, 3, 3, 1, 1, 1, 1, 1, 1, 2]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import time
|
||||
|
||||
import pytest as pytest
|
||||
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
@@ -9,7 +11,12 @@ from tests import TEST_CONFIG_ROOT
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session):
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session_2",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session, temp_primaite_session_2):
|
||||
"""Test running seeded learning produces the same output when run twice."""
|
||||
"""
|
||||
expected_mean_reward_per_episode = {
|
||||
@@ -31,8 +38,8 @@ def test_seeded_learning(temp_primaite_session):
|
||||
)
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode_run_1 = session.learn_av_reward_per_episode()
|
||||
|
||||
with temp_primaite_session as session:
|
||||
time.sleep(2)
|
||||
with temp_primaite_session_2 as session:
|
||||
assert session._training_config.seed == 67890, (
|
||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user