Merge remote-tracking branch 'origin/dev' into feature/2257-router-routes-cannot-be-represented-in-config-file

This commit is contained in:
Czar Echavez
2024-02-29 11:34:56 +00:00
58 changed files with 548 additions and 963 deletions

View File

@@ -14,7 +14,7 @@ io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_pcap_logs: false
save_sys_logs: true
@@ -656,12 +656,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server

View File

@@ -1,15 +1,23 @@
training_config:
rl_framework: RLLIB_multi_agent
# rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 256
deterministic_eval: false
n_agents: 2
agent_references:
- defender_1
- defender_2
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true
game:
@@ -36,9 +44,9 @@ agents:
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
- node_name: client_2
applications:
- application_ref: client_2_web_browser
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -54,6 +62,31 @@ agents:
frequency: 4
variance: 3
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
@@ -72,7 +105,7 @@ agents:
- type: NODE_OS_SCAN
options:
nodes:
- node_ref: client_1
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
@@ -104,25 +137,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,23 +166,23 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
- node_hostname: domain_controller
nic_num: 1
- node_ref: web_server
- node_hostname: web_server
nic_num: 1
- node_ref: database_server
- node_hostname: database_server
nic_num: 1
- node_ref: backup_server
- node_hostname: backup_server
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 1
- node_ref: client_1
- node_hostname: client_1
nic_num: 1
- node_ref: client_2
- node_hostname: client_2
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 2
ics: null
@@ -184,10 +213,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -242,25 +271,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -271,22 +300,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -303,63 +332,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -407,123 +436,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
options:
nodes:
- node_ref: domain_controller
- node_ref: web_server
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_ref: database_service
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
weight: 0.34
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_hostname: client_2
agent_settings:
# ...
flatten_obs: true
- ref: defender_2
team: BLUE
@@ -537,25 +591,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -570,23 +620,23 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
- node_hostname: domain_controller
nic_num: 1
- node_ref: web_server
- node_hostname: web_server
nic_num: 1
- node_ref: database_server
- node_hostname: database_server
nic_num: 1
- node_ref: backup_server
- node_hostname: backup_server
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 1
- node_ref: client_1
- node_hostname: client_1
nic_num: 1
- node_ref: client_2
- node_hostname: client_2
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 2
ics: null
@@ -617,10 +667,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -675,25 +725,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -704,22 +754,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -736,63 +786,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -840,122 +890,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
options:
nodes:
- node_ref: domain_controller
- node_ref: web_server
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_ref: database_service
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
weight: 0.34
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_hostname: client_2
agent_settings:
# ...
flatten_obs: true
@@ -1036,12 +1112,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
@@ -1093,10 +1170,14 @@ simulation:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_1_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_1_dns_client
type: DNSClient
@@ -1113,6 +1194,13 @@ simulation:
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient

View File

@@ -83,18 +83,15 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.agents: Dict[str, AbstractAgent] = {}
"""Mapping from agent name to agent object."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.rl_agents: Dict[str, ProxyAgent] = {}
"""Subset of agents which are intended for reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
@@ -140,7 +137,7 @@ class PrimaiteGame:
self.update_agents(sim_state)
# Apply all actions to simulation as requests
agent_actions = self.apply_agent_actions() # noqa
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
@@ -151,7 +148,7 @@ class PrimaiteGame:
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
for _, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
agent.reward_function.total_reward += agent.reward_function.current_reward
@@ -165,7 +162,7 @@ class PrimaiteGame:
"""
agent_actions = {}
for agent in self.agents:
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
@@ -188,20 +185,14 @@ class PrimaiteGame:
return True
return False
def reset(self) -> None:
"""Reset the game, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(episode=self.episode_counter)
for agent in self.agents:
agent.reward_function.total_reward = 0.0
agent.reset_agent_for_episode()
def close(self) -> None:
"""Close the game, this will close the simulation."""
return NotImplemented
def setup_for_episode(self, episode: int) -> None:
"""Perform any final configuration of components to make them ready for the game to start."""
self.simulation.setup_for_episode(episode=episode)
@classmethod
def from_config(cls, cfg: Dict) -> "PrimaiteGame":
"""Create a PrimaiteGame object from a config dictionary.
@@ -280,8 +271,9 @@ class PrimaiteGame:
# start the service
new_service.start()
else:
_LOGGER.warning(f"service type not found {service_type}")
msg = f"Configuration contains an invalid service type: {service_type}"
_LOGGER.error(msg)
raise ValueError(msg)
# service-dependent options
if service_type == "DNSClient":
if "options" in service_cfg:
@@ -318,7 +310,9 @@ class PrimaiteGame:
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application.uuid
else:
_LOGGER.warning(f"application type not found {application_type}")
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
raise ValueError(msg)
# run the application
new_application.run()
@@ -419,7 +413,6 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -428,8 +421,7 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -438,11 +430,11 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
else:
_LOGGER.warning(f"agent type {agent_type} not found")
game.simulation.set_original_state()
msg(f"Configuration error: {agent_type} is not a valid agent type.")
_LOGGER.error(msg)
raise ValueError(msg)
game.agents[agent_cfg["ref"]] = new_agent
# Set the NMNE capture config
set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {}))

View File

@@ -60,7 +60,7 @@
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
@@ -88,6 +88,13 @@
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -54,7 +54,7 @@
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"cfg\":cfg}\n",
"env_config = cfg\n",
"\n",
"config = (\n",
" PPOConfig()\n",

View File

@@ -27,9 +27,7 @@
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
" cfg = yaml.safe_load(f)\n"
]
},
{
@@ -38,7 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game=game)"
"gym = PrimaiteGymEnv(game_config=cfg)"
]
},
{
@@ -65,7 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=1000)\n"
"model.learn(total_timesteps=10)\n"
]
},
{
@@ -76,6 +74,13 @@
"source": [
"model.save(\"deleteme\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -346,9 +346,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
@@ -358,9 +356,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
@@ -383,9 +379,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"# create the env\n",
@@ -396,10 +390,10 @@
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
" cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n",
"game = PrimaiteGame.from_config(cfg)\n",
"env = PrimaiteGymEnv(game = game)\n",
"# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n",
"env.agent.flatten_obs = False\n",
" # don't flatten observations so that we can see what is going on\n",
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",
"obs, info = env.reset()\n",
"print('env created successfully')\n",
"pprint(obs)"
@@ -433,9 +427,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"for step in range(35):\n",
@@ -453,9 +445,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"pprint(obs['NODES'])"
@@ -471,9 +461,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(9) # scan database file\n",
@@ -499,9 +487,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
@@ -526,9 +512,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
@@ -551,9 +535,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"metadata": {},
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
@@ -593,6 +575,22 @@
"obs['ACL']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -603,7 +601,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "venv",
"language": "python",
"name": "python3"
},
@@ -617,9 +615,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 2
}

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
from typing import Any, Dict, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -18,11 +18,23 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game: PrimaiteGame):
def __init__(self, game_config: Dict):
"""Initialise the environment."""
super().__init__()
self.game: "PrimaiteGame" = game
self.agent: ProxyAgent = self.game.rl_agents[0]
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
return self.game.rl_agents[self._agent_name]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
@@ -45,13 +57,13 @@ class PrimaiteGymEnv(gymnasium.Env):
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.game.episode_counter,
"episode": self.episode_counter,
"step": self.game.step_counter,
"action": int(action),
"reward": int(reward),
@@ -63,10 +75,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
print(
f"Resetting environment, episode {self.game.episode_counter}, "
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
self.game.reset()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -88,12 +102,12 @@ class PrimaiteGymEnv(gymnasium.Env):
def _get_obs(self) -> ObsType:
"""Return the current observation."""
if not self.agent.flatten_obs:
return self.agent.observation_manager.current_observation
else:
if self.agent.flatten_obs:
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
else:
return self.agent.observation_manager.current_observation
class PrimaiteRayEnv(gymnasium.Env):
@@ -102,12 +116,11 @@ class PrimaiteRayEnv(gymnasium.Env):
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
self.env.game.episode_counter -= 1
self.env = PrimaiteGymEnv(game_config=env_config)
self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -128,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
:type env_config: Dict
"""
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Reference to the primaite game"""
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
"""List of all possible agents in the environment. This list should not change!"""
self._agent_ids = list(self.agents.keys())
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.episode_counter: int = 0
"""Current episode number."""
self.terminateds = set()
self.truncateds = set()
@@ -149,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
)
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
self.game.reset()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -172,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -186,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {"agent_actions": agent_actions}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
@@ -194,13 +217,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.game.episode_counter,
"episode": self.episode_counter,
"step": self.game.step_counter,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
@@ -212,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for name, agent in self.agents.items():
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs

View File

@@ -101,11 +101,11 @@ class PrimaiteSession:
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config={"cfg": cfg})
sess.env = PrimaiteRayEnv(env_config=cfg)
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg})
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game=game)
sess.env = PrimaiteGymEnv(game_config=cfg)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:

View File

@@ -153,22 +153,18 @@ class SimComponent(BaseModel):
uuid: str = Field(default_factory=lambda: str(uuid4()))
"""The component UUID."""
_original_state: Dict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
# @abstractmethod
def set_original_state(self):
"""Sets the original state."""
pass
def setup_for_episode(self, episode: int):
"""
Perform any additional setup on this component that can't happen during __init__.
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for key, value in self._original_state.items():
self.__setattr__(key, value)
For instance, some components may require for the entire network to exist before some configuration can be set.
"""
pass
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -42,19 +42,6 @@ class Account(SimComponent):
"Account Type, currently this can be service account (used by apps) or user account."
enabled: bool = True
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"num_logons",
"num_logoffs",
"num_group_changes",
"username",
"password",
"account_type",
"enabled",
}
self._original_state = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -73,20 +73,6 @@ class File(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}")
super().set_original_state()
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
super().reset_component_for_episode(episode)
@property
def path(self) -> str:
"""

View File

@@ -34,43 +34,6 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}")
for folder in self.folders.values():
folder.set_original_state()
# Capture a list of all 'original' file uuids
original_keys = list(self.folders.keys())
vals_to_include = {"sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_folder_uuids"] = original_keys
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
# Move any 'original' folder that have been deleted back to folders
original_folder_uuids = self._original_state["original_folder_uuids"]
for uuid in original_folder_uuids:
if uuid in self.deleted_folders:
folder = self.deleted_folders[uuid]
self.deleted_folders.pop(uuid)
self.folders[uuid] = folder
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
# Now clear all non-original folders created by agent
current_folder_uuids = list(self.folders.keys())
for uuid in current_folder_uuids:
if uuid not in original_folder_uuids:
folder = self.folders[uuid]
self.folders.pop(uuid)
# Now reset all remaining folders
for folder in self.folders.values():
folder.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -85,11 +85,6 @@ class FileSystemItemABC(SimComponent):
deleted: bool = False
"If true, the FileSystemItem was deleted."
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -49,49 +49,6 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}")
for file in self.files.values():
file.set_original_state()
super().set_original_state()
vals_to_include = {
"scan_duration",
"scan_countdown",
"red_scan_duration",
"red_scan_countdown",
"restore_duration",
"restore_countdown",
}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_file_uuids"] = list(self.files.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
# Move any 'original' file that have been deleted back to files
original_file_uuids = self._original_state["original_file_uuids"]
for uuid in original_file_uuids:
if uuid in self.deleted_files:
file = self.deleted_files[uuid]
self.deleted_files.pop(uuid)
self.files[uuid] = file
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
# Now clear all non-original files created by agent
current_file_uuids = list(self.files.keys())
for uuid in current_file_uuids:
if uuid not in original_file_uuids:
file = self.files[uuid]
self.files.pop(uuid)
# Now reset all remaining files
for file in self.files.values():
file.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(

View File

@@ -273,11 +273,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
return state
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def enable(self):
"""
Enables this wired network interface and attempts to send a "hello" message to the default gateway.

View File

@@ -45,19 +45,12 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
def set_original_state(self):
"""Sets the original state."""
for node in self.nodes.values():
node.set_original_state()
for link in self.links.values():
link.set_original_state()
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.reset_component_for_episode(episode)
node.setup_for_episode(episode=episode)
for link in self.links.values():
link.reset_component_for_episode(episode)
link.setup_for_episode(episode=episode)
for node in self.nodes.values():
node.power_on()
@@ -179,7 +172,7 @@ class Network(SimComponent):
def clear_links(self):
"""Clear all the links in the network by resetting their component state for the episode."""
for link in self.links.values():
link.reset_component_for_episode()
link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here.
def draw(self, seed: int = 123):
"""

View File

@@ -100,6 +100,15 @@ class NetworkInterface(SimComponent, ABC):
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)
self.nmne = {}
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -127,15 +136,6 @@ class NetworkInterface(SimComponent, ABC):
state.update({"nmne": self.nmne})
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
self.nmne = {}
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
@abstractmethod
def enable(self):
"""Enable the interface."""
@@ -547,14 +547,6 @@ class Link(SimComponent):
self.endpoint_b.connect_link(self)
self.endpoint_up()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"bandwidth", "current_load"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -740,50 +732,20 @@ class Node(SimComponent):
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self._install_system_software()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
for software in self.software_manager.software.values():
software.set_original_state()
self.file_system.set_original_state()
for network_interface in self.network_interfaces.values():
network_interface.set_original_state()
vals_to_include = {
"hostname",
"default_gateway",
"operating_state",
"revealed_to_red",
"start_up_duration",
"start_up_countdown",
"shut_down_duration",
"shut_down_countdown",
"is_resetting",
"node_scan_duration",
"node_scan_countdown",
"red_scan_countdown",
}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
# Reset Session Manager
self.session_manager.clear()
super().setup_for_episode(episode=episode)
# Reset File System
self.file_system.reset_component_for_episode(episode)
self.file_system.setup_for_episode(episode=episode)
# Reset all Nics
for network_interface in self.network_interfaces.values():
network_interface.reset_component_for_episode(episode)
network_interface.setup_for_episode(episode=episode)
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
software.setup_for_episode(episode=episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode

View File

@@ -209,11 +209,6 @@ class NIC(IPWiredNetworkInterface):
return state
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def receive_frame(self, frame: Frame) -> bool:
"""
Attempt to receive and process a network frame from the connected Link.

View File

@@ -111,24 +111,6 @@ class Firewall(Router):
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
)
self.set_original_state()
def set_original_state(self):
"""Set the original state for the Firewall."""
super().set_original_state()
vals_to_include = {
"internal_port",
"external_port",
"dmz_port",
"internal_inbound_acl",
"internal_outbound_acl",
"dmz_inbound_acl",
"dmz_outbound_acl",
"external_inbound_acl",
"external_outbound_acl",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def describe_state(self) -> Dict:
"""
Describes the current state of the Firewall.

View File

@@ -136,11 +136,6 @@ class ACLRule(SimComponent):
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -296,48 +291,6 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.implicit_rule.set_original_state()
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
for i, rule in enumerate(self._acl):
if not rule:
continue
self._default_config[i] = {"action": rule.action.name}
if rule.src_ip_address:
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
if rule.dst_ip_address:
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
if rule.src_port:
self._default_config[i]["src_port"] = rule.src_port.name
if rule.dst_port:
self._default_config[i]["dst_port"] = rule.dst_port.name
if rule.protocol:
self._default_config[i]["protocol"] = rule.protocol.name
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
self._reset_rules_to_default()
def _reset_rules_to_default(self) -> None:
"""Clear all ACL rules and set them to the default rules config."""
self._acl = [None] * (self.max_acl_rules - 1)
for r_num, r_cfg in self._default_config.items():
self.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)
def _init_request_manager(self) -> RequestManager:
# TODO: Add src and dst wildcard masks as positional args in this request.
@@ -616,11 +569,6 @@ class RouteEntry(SimComponent):
metric: float = 0.0
"The cost metric for this route. Default is 0.0."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
self._original_values = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteEntry.
@@ -653,17 +601,6 @@ class RouteTable(SimComponent):
default_route: Optional[RouteEntry] = None
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.routes.clear()
self.routes = self._original_state["routes_orig"]
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteTable.
@@ -1104,8 +1041,6 @@ class Router(NetworkNode):
self._set_default_acl()
self.set_original_state()
def _install_system_software(self):
"""
Installs essential system software and network services on the router.
@@ -1131,20 +1066,7 @@ class Router(NetworkNode):
self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
def set_original_state(self):
"""
Sets or resets the router to its original configuration state.
This method is called to initialize the router's state or to revert it to a known good configuration during
network simulations or after configuration changes.
"""
self.acl.set_original_state()
self.route_table.set_original_state()
super().set_original_state()
vals_to_include = {"num_ports"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""
Resets the router's components for a new network simulation episode.
@@ -1154,13 +1076,10 @@ class Router(NetworkNode):
:param episode: The episode number for which the router is being reset.
"""
self.software_manager.arp.clear()
self.acl.reset_component_for_episode(episode)
self.route_table.reset_component_for_episode(episode)
for i, network_interface in self.network_interface.items():
network_interface.reset_component_for_episode(episode)
for i, _ in self.network_interface.items():
self.enable_port(i)
super().reset_component_for_episode(episode)
super().setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -1391,7 +1310,6 @@ class Router(NetworkNode):
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured Network Interface {network_interface}")
self.set_original_state()
def enable_port(self, port: int):
"""
@@ -1493,8 +1411,16 @@ class Router(NetworkNode):
subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")),
)
if "acl" in cfg:
router.acl._default_config = cfg["acl"] # save the config to allow resetting
router.acl._reset_rules_to_default() # read the config and apply rules
for r_num, r_cfg in cfg["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
router.route_table.add_route(

View File

@@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface):
_connected_node: Optional[Switch] = None
"The Switch to which the SwitchPort is connected."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -122,8 +122,6 @@ class WirelessRouter(Router):
self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0"))
self.set_original_state()
@property
def wireless_access_point(self) -> WirelessAccessPoint:
"""
@@ -166,7 +164,6 @@ class WirelessRouter(Router):
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured WAP {network_interface}")
self.set_original_state()
self.wireless_access_point.frequency = frequency # Set operating frequency
self.wireless_access_point.enable() # Re-enable the WAP with new settings

View File

@@ -21,13 +21,9 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
self.network.set_original_state()
def reset_component_for_episode(self, episode: int):
def setup_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.network.reset_component_for_episode(episode)
self.network.setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -38,12 +38,6 @@ class Application(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -31,20 +31,6 @@ class DatabaseClient(Application):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
self._query_success_tracker.clear()
def describe_state(self) -> Dict:
"""

View File

@@ -49,26 +49,6 @@ class DataManipulationBot(DatabaseClient):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"server_ip_address",
"payload",
"server_password",
"port_scan_p_of_success",
"data_manipulation_p_of_success",
"attack_stage",
"repeat",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -57,27 +57,6 @@ class DoSBot(DatabaseClient, Application):
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
def set_original_state(self):
"""Set the original state of the Denial of Service Bot."""
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"target_ip_address",
"target_port",
"payload",
"repeat",
"attack_stage",
"max_sessions",
"port_scan_p_of_success",
"dos_intensity",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -47,21 +47,8 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.set_original_state()
self.run()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -80,9 +67,6 @@ class WebBrowser(Application):
state["history"] = [hist_item.state() for hist_item in self.history]
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
def get_webpage(self, url: Optional[str] = None) -> bool:
"""
Retrieve the webpage.

View File

@@ -24,12 +24,6 @@ class Process(Software):
operating_state: ProcessOperatingState
"The current operating state of the Process."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -41,25 +41,6 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._create_db_file()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"password",
"connections",
"backup_server_ip",
"latest_backup_directory",
"latest_backup_file_name",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.clear_connections()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.

View File

@@ -29,18 +29,6 @@ class DNSClient(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_server"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_cache.clear()
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.

View File

@@ -28,20 +28,6 @@ class DNSServer(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_table"}
self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_table.clear()
for key, value in self._original_state["dns_table_orig"].items():
self.dns_table[key] = value
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.

View File

@@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"connected"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_password"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
self.clear_connections()
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -49,21 +49,12 @@ class NTPClient(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def send(
self,
payload: NTPPacket,
session_id: Optional[str] = None,
dest_ip_address: IPv4Address = None,
dest_port: [Port] = Port.NTP,
dest_port: Port = Port.NTP,
**kwargs,
) -> bool:
"""Requests NTP data from NTP server.

View File

@@ -34,16 +34,6 @@ class NTPServer(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including
resetting any stateful properties or statistics, and clearing any message
queues.
"""
pass
def receive(
self,
payload: NTPPacket,

View File

@@ -78,12 +78,6 @@ class Service(IOSoftware):
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))

View File

@@ -23,18 +23,6 @@ class WebServer(Service):
last_response_status_code: Optional[HttpStatusCode] = None
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"last_response_status_code"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -130,7 +118,7 @@ class WebServer(Service):
self.set_health_state(SoftwareHealthState.COMPROMISED)
return response
except Exception:
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response

View File

@@ -3,7 +3,7 @@ from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -13,6 +13,9 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
if TYPE_CHECKING:
from primaite.simulator.system.core.software_manager import SoftwareManager
class SoftwareType(Enum):
"""
@@ -84,7 +87,7 @@ class Software(SimComponent):
"The count of times the software has been scanned, defaults to 0."
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
software_manager: Any = None
software_manager: "SoftwareManager" = None
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."
@@ -97,19 +100,6 @@ class Software(SimComponent):
_patching_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"name",
"health_state_actual",
"health_state_visible",
"criticality",
"patching_count",
"scanning_count",
"revealed_to_red",
}
self._original_state = self.model_dump(include=vals_to_include)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -248,12 +238,6 @@ class IOSoftware(Software):
_connections: Dict[str, Dict] = {}
"Active connections."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""