Change agents list in game object to dictionary

This commit is contained in:
Marek Wolan
2024-02-26 10:26:28 +00:00
parent 63c9a36c30
commit e5982c4599
8 changed files with 331 additions and 225 deletions

View File

@@ -10,6 +10,8 @@ io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
game:
@@ -36,9 +38,9 @@ agents:
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
- node_name: client_2
applications:
- application_ref: client_2_web_browser
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -54,6 +56,31 @@ agents:
frequency: 4
variance: 3
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
@@ -72,7 +99,7 @@ agents:
- type: NODE_OS_SCAN
options:
nodes:
- node_ref: client_1
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
@@ -104,25 +131,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -137,23 +160,23 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
- node_hostname: domain_controller
nic_num: 1
- node_ref: web_server
- node_hostname: web_server
nic_num: 1
- node_ref: database_server
- node_hostname: database_server
nic_num: 1
- node_ref: backup_server
- node_hostname: backup_server
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 1
- node_ref: client_1
- node_hostname: client_1
nic_num: 1
- node_ref: client_2
- node_hostname: client_2
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 2
ics: null
@@ -184,10 +207,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -242,25 +265,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -271,22 +294,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -303,63 +326,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -407,123 +430,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
options:
nodes:
- node_ref: domain_controller
- node_ref: web_server
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_ref: database_service
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
weight: 0.34
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_hostname: client_2
agent_settings:
# ...
flatten_obs: true
- ref: defender_2
team: BLUE
@@ -537,25 +585,21 @@ agents:
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
- node_hostname: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
- service_name: DNSServer
- node_hostname: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
@@ -570,23 +614,23 @@ agents:
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
router_hostname: router_1
ip_address_order:
- node_ref: domain_controller
- node_hostname: domain_controller
nic_num: 1
- node_ref: web_server
- node_hostname: web_server
nic_num: 1
- node_ref: database_server
- node_hostname: database_server
nic_num: 1
- node_ref: backup_server
- node_hostname: backup_server
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 1
- node_ref: client_1
- node_hostname: client_1
nic_num: 1
- node_ref: client_2
- node_hostname: client_2
nic_num: 1
- node_ref: security_suite
- node_hostname: security_suite
nic_num: 2
ics: null
@@ -617,10 +661,10 @@ agents:
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
@@ -675,25 +719,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
@@ -704,22 +748,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -736,63 +780,63 @@ agents:
action: "NODE_RESET"
options:
node_id: 5
22:
22: # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
23: # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 2
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 3
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 4
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 5
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
position: 6
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
@@ -840,122 +884,148 @@ agents:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 1
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 2
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
nic_id: 0
options:
nodes:
- node_ref: domain_controller
- node_ref: web_server
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_ref: web_server_web_service
- node_ref: database_server
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_ref: database_service
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
weight: 0.34
options:
node_ref: database_server
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_ref: web_server
service_ref: web_server_web_service
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.33
options:
node_hostname: client_2
agent_settings:
# ...
flatten_obs: true
@@ -1032,12 +1102,13 @@ simulation:
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
@@ -1089,10 +1160,14 @@ simulation:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_1_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_1_dns_client
type: DNSClient
@@ -1109,6 +1184,13 @@ simulation:
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient

View File

@@ -79,11 +79,11 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.agents: Dict[str, AbstractAgent] = {}
"""Mapping from agent name to agent object."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.rl_agents: Dict[str, ProxyAgent] = {}
"""Subset of agents which are intended for reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
@@ -144,7 +144,7 @@ class PrimaiteGame:
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
for name, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
agent.reward_function.total_reward += agent.reward_function.current_reward
@@ -158,7 +158,7 @@ class PrimaiteGame:
"""
agent_actions = {}
for agent in self.agents:
for name, agent in self.agents.items():
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
@@ -396,7 +396,6 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
@@ -405,8 +404,7 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -415,8 +413,8 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
else:
_LOGGER.warning(f"agent type {agent_type} not found")
game.agents[agent_cfg["ref"]] = new_agent
return game

View File

@@ -60,7 +60,7 @@
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
@@ -88,6 +88,13 @@
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -54,7 +54,7 @@
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"cfg\":cfg}\n",
"env_config = cfg\n",
"\n",
"config = (\n",
" PPOConfig()\n",

View File

@@ -27,9 +27,7 @@
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
" cfg = yaml.safe_load(f)\n"
]
},
{
@@ -76,6 +74,13 @@
"source": [
"model.save(\"deleteme\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
from typing import Any, Dict, Optional, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -25,12 +25,17 @@ class PrimaiteGymEnv(gymnasium.Env):
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Current game."""
self.agent: ProxyAgent = self.game.rl_agents[0]
"""The agent within the game that is controlled by the RL algorithm."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
return self.game.rl_agents[self._agent_name]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen by the RL policy
@@ -71,11 +76,10 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Reset the environment."""
print(
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
f"avg. reward: {self.agent.reward_function.total_reward}"
)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.agent = self.game.rl_agents[0]
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
@@ -112,11 +116,10 @@ class PrimaiteRayEnv(gymnasium.Env):
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
self.env = PrimaiteGymEnv(game_config=env_config)
self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -138,13 +141,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
:param env_config: A dictionary containing the game configuration. It is passed directly to
`PrimaiteGame.from_config` to build the game.
:type env_config: Dict[str, PrimaiteGame]
:type env_config: Dict
"""
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
"""Reference to the primaite game"""
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
"""List of all possible agents in the environment. This list should not change!"""
self._agent_ids = list(self.agents.keys())
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.episode_counter: int = 0
"""Current episode number."""
self.terminateds = set()
self.truncateds = set()
@@ -159,9 +165,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
)
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
self.game.reset()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -182,7 +195,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -196,7 +209,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {"agent_actions": agent_actions}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
@@ -222,8 +235,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for name, agent in self.agents.items():
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs

View File

@@ -510,6 +510,6 @@ def game_and_agent():
reward_function=reward_function,
)
game.agents.append(test_agent)
game.agents["test_agent"] = test_agent
return (game, test_agent)

View File

@@ -42,20 +42,20 @@ def test_example_config():
assert len(game.agents) == 4 # red, blue and 2 green agents
# green agent 1
assert game.agents[0].agent_name == "client_2_green_user"
assert isinstance(game.agents[0], RandomAgent)
assert "client_2_green_user" in game.agents
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
# green agent 2
assert game.agents[1].agent_name == "client_1_green_user"
assert isinstance(game.agents[1], RandomAgent)
assert "client_1_green_user" in game.agents
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
# red agent
assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot"
assert isinstance(game.agents[2], DataManipulationAgent)
assert "client_1_data_manipulation_red_bot" in game.agents
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
# blue agent
assert game.agents[3].agent_name == "defender"
assert isinstance(game.agents[3], ProxyAgent)
assert "defender" in game.agents
assert isinstance(game.agents["defender"], ProxyAgent)
network: Network = game.simulation.network