Modify tests based on refactoring
This commit is contained in:
@@ -101,11 +101,11 @@ class PrimaiteSession:
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
|
||||
sess.env = PrimaiteRayEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
|
||||
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game=game)
|
||||
sess.env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
|
||||
@@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
|
||||
|
||||
return state
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def enable(self):
|
||||
"""
|
||||
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
|
||||
|
||||
@@ -213,11 +213,6 @@ class NIC(IPWiredNetworkInterface):
|
||||
|
||||
return state
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Attempt to receive and process a network frame from the connected Link.
|
||||
|
||||
@@ -109,24 +109,6 @@ class Firewall(Router):
|
||||
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
|
||||
)
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
def set_original_state(self):
|
||||
"""Set the original state for the Firewall."""
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"internal_port",
|
||||
"external_port",
|
||||
"dmz_port",
|
||||
"internal_inbound_acl",
|
||||
"internal_outbound_acl",
|
||||
"dmz_inbound_acl",
|
||||
"dmz_outbound_acl",
|
||||
"external_inbound_acl",
|
||||
"external_outbound_acl",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the Firewall.
|
||||
|
||||
@@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface):
|
||||
_connected_node: Optional[Switch] = None
|
||||
"The Switch to which the SwitchPort is connected."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
|
||||
self._original_state = self.model_dump(include=vals_to_include)
|
||||
super().set_original_state()
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -122,8 +122,6 @@ class WirelessRouter(Router):
|
||||
|
||||
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
|
||||
|
||||
self.set_original_state()
|
||||
|
||||
@property
|
||||
def wireless_access_point(self) -> WirelessAccessPoint:
|
||||
"""
|
||||
@@ -166,7 +164,6 @@ class WirelessRouter(Router):
|
||||
network_interface.ip_address = ip_address
|
||||
network_interface.subnet_mask = subnet_mask
|
||||
self.sys_log.info(f"Configured WAP {network_interface}")
|
||||
self.set_original_state()
|
||||
self.wireless_access_point.frequency = frequency # Set operating frequency
|
||||
self.wireless_access_point.enable() # Re-enable the WAP with new settings
|
||||
|
||||
|
||||
@@ -589,15 +589,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
|
||||
@@ -593,15 +593,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -624,7 +625,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -1043,16 +1043,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
@@ -1074,7 +1074,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -599,15 +599,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -630,7 +631,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -600,15 +600,16 @@ simulation:
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
@@ -631,7 +632,7 @@ simulation:
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
|
||||
@@ -17,8 +17,7 @@ def test_sb3_compatibility():
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
gym = PrimaiteGymEnv(game=game)
|
||||
gym = PrimaiteGymEnv(game_config=cfg)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
|
||||
@@ -12,20 +12,6 @@ def account() -> Account:
|
||||
|
||||
def test_original_state(account):
|
||||
"""Test the original state - see if it resets properly"""
|
||||
account.log_on()
|
||||
account.log_off()
|
||||
account.disable()
|
||||
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is False
|
||||
|
||||
account.reset_component_for_episode(episode=1)
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 0
|
||||
assert state["num_logoffs"] is 0
|
||||
@@ -39,11 +25,6 @@ def test_original_state(account):
|
||||
account.log_off()
|
||||
account.disable()
|
||||
|
||||
account.log_on()
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 2
|
||||
|
||||
account.reset_component_for_episode(episode=2)
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
|
||||
@@ -185,37 +185,6 @@ def test_get_file(file_system):
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
def test_reset_file_system(file_system):
|
||||
# file and folder that existed originally
|
||||
file_system.create_file(file_name="test_file.zip")
|
||||
file_system.create_folder(folder_name="test_folder")
|
||||
|
||||
# create a new file
|
||||
file_system.create_file(file_name="new_file.txt")
|
||||
|
||||
# create a new folder
|
||||
file_system.create_folder(folder_name="new_folder")
|
||||
|
||||
# delete the file that existed originally
|
||||
file_system.delete_file(folder_name="root", file_name="test_file.zip")
|
||||
assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
|
||||
|
||||
# delete the folder that existed originally
|
||||
file_system.delete_folder(folder_name="test_folder")
|
||||
assert file_system.get_folder(folder_name="test_folder") is None
|
||||
|
||||
# reset
|
||||
file_system.reset_component_for_episode(episode=1)
|
||||
|
||||
# deleted original file and folder should be back
|
||||
assert file_system.get_file(folder_name="root", file_name="test_file.zip")
|
||||
assert file_system.get_folder(folder_name="test_folder")
|
||||
|
||||
# new file and folder should be removed
|
||||
assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
|
||||
assert file_system.get_folder(folder_name="new_folder") is None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
|
||||
def test_serialisation(file_system):
|
||||
"""Test to check that the object serialisation works correctly."""
|
||||
|
||||
@@ -44,40 +44,6 @@ def test_describe_state(network):
|
||||
assert len(state["links"]) is 6
|
||||
|
||||
|
||||
def test_reset_network(network):
|
||||
"""
|
||||
Test that the network is properly reset.
|
||||
|
||||
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
|
||||
etc are also removed/reinstalled
|
||||
|
||||
"""
|
||||
state_before = network.describe_state()
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
server_1: Computer = network.get_node_by_hostname("server_1")
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
client_1.power_off()
|
||||
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
server_1.power_off()
|
||||
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
assert network.describe_state() != state_before
|
||||
|
||||
network.reset_component_for_episode(episode=1)
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
# don't worry if UUIDs change
|
||||
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
|
||||
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_creating_container():
|
||||
"""Check that we can create a network container"""
|
||||
net = Network()
|
||||
|
||||
@@ -27,34 +27,6 @@ def test_dos_bot_creation(dos_bot):
|
||||
assert dos_bot is not None
|
||||
|
||||
|
||||
def test_dos_bot_reset(dos_bot):
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
|
||||
# should reset the relevant items
|
||||
dos_bot.reset_component_for_episode(episode=0)
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
dos_bot.reset_component_for_episode(episode=1)
|
||||
# should reset to the configured value
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
|
||||
assert dos_bot.target_port is Port.HTTP
|
||||
assert dos_bot.payload == "payload"
|
||||
assert dos_bot.repeat is True
|
||||
|
||||
|
||||
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
|
||||
dos_bot_node: Computer = dos_bot.parent
|
||||
assert dos_bot_node.operating_state is NodeOperatingState.ON
|
||||
|
||||
Reference in New Issue
Block a user