Files
PrimAITE/src/primaite/game/game.py

581 lines
28 KiB
Python
Raw Normal View History

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
2023-11-26 23:29:14 +00:00
"""PrimAITE game - Encapsulates the simulation and agents."""
2023-10-06 10:36:29 +01:00
from ipaddress import IPv4Address
2024-03-13 12:11:50 +00:00
from typing import Dict, List, Optional
2023-10-06 10:36:29 +01:00
import numpy as np
from pydantic import BaseModel, ConfigDict
2023-10-06 10:36:29 +01:00
from primaite import DEFAULT_BANDWIDTH, getLogger
2023-10-06 10:36:29 +01:00
from primaite.game.agent.actions import ActionManager
2024-02-27 13:30:16 +00:00
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
2024-03-11 22:53:39 +00:00
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
2024-04-15 11:50:08 +01:00
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
2024-03-11 22:53:39 +00:00
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
2024-07-17 10:43:04 +01:00
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import NmneData, store_nmne_config
from primaite.simulator.network.transmission.transport_layer import Port
2023-10-06 10:36:29 +01:00
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( # noqa: F401
DataManipulationBot,
)
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot # noqa: F401
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript # noqa: F401
from primaite.simulator.system.applications.web_browser import WebBrowser # noqa: F401
2023-10-09 18:35:30 +01:00
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
2023-10-23 16:26:34 +01:00
from primaite.simulator.system.services.web_server.web_server import WebServer
2023-10-06 10:36:29 +01:00
# Module-level logger, obtained from PrimAITE's own ``getLogger`` helper.
_LOGGER = getLogger(__name__)
# Maps the ``type`` string of a service entry in a node's config to the service class to install.
# ``PrimaiteGame.from_config`` looks each configured service type up here and raises ``ValueError``
# for any type that is not a key of this mapping.
SERVICE_TYPES_MAPPING = {
"DNSClient": DNSClient,
"DNSServer": DNSServer,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
}
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
2023-10-19 01:56:40 +01:00
2023-11-26 23:29:14 +00:00
class PrimaiteGameOptions(BaseModel):
    """
    Global options which are applicable to all of the agents in the game.

    Currently this is used to restrict which ports and protocols exist in the world of the simulation.
    """

    # Reject unknown keys so that typos in the game section of the config fail loudly.
    model_config = ConfigDict(extra="forbid")

    max_episode_length: int = 256
    """Maximum number of steps in one episode; the episode is truncated once the step counter reaches it."""
    ports: List[str]
    """A whitelist of available ports in the simulation."""
    protocols: List[str]
    """A whitelist of available protocols in the simulation."""
    thresholds: Optional[Dict] = {}
    """A dict containing the thresholds used for determining what is acceptable during observations."""
2023-10-02 17:21:43 +01:00
2023-10-19 01:56:40 +01:00
2023-11-26 23:29:14 +00:00
class PrimaiteGame:
    """
    Primaite game encapsulates the simulation and agents which interact with it.

    Provides main logic loop for the game. However, it does not provide policy training, or a gymnasium environment.
    """

    def __init__(self):
        """Initialise a PrimaiteGame object."""
        self.simulation: Simulation = Simulation()
        """Simulation object with which the agents will interact."""
        self.agents: Dict[str, AbstractAgent] = {}
        """Mapping from agent name to agent object."""
        self.rl_agents: Dict[str, ProxyAgent] = {}
        """Subset of agents which are intended for reinforcement learning."""
        self.step_counter: int = 0
        """Current timestep within the episode."""
        self.options: PrimaiteGameOptions
        """Special options that apply for the entire game."""
        self.save_step_metadata: bool = False
        """Whether to save the RL agents' action, environment state, and other data at every single step."""
        self._reward_calculation_order: List[str] = list(self.agents)
        """Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
        self.nmne_config: Optional[NmneData] = None
        """Config data from Number of Malicious Network Events."""

    def step(self):
        """
        Perform one step of the simulation/agent loop.

        This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
        agent. The steps are as follows:
            1. The simulation state is updated.
            2. The simulation state is sent to each agent.
            3. Each agent converts the state to an observation and calculates a reward.
            4. Each agent chooses an action based on the observation.
            5. Each agent converts the action to a request.
            6. The simulation applies the requests.

        Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
        interacts with should implement a step method that calls methods used by this method. For example, if using a
        single-agent gym, make sure to update the ProxyAgent's action with the action before calling
        ``self.apply_agent_actions()``.
        """
        _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")

        self.pre_timestep()

        # On the very first step, refresh every agent's observation before any action is chosen.
        if self.step_counter == 0:
            state = self.get_sim_state()
            for agent in self.agents.values():
                agent.update_observation(state=state)

        # Apply all actions to simulation as requests
        self.apply_agent_actions()

        # Advance timestep
        self.advance_timestep()

        # Get the current state of the simulation
        sim_state = self.get_sim_state()

        # Update agents' observations and rewards based on the current state, and the response from the last action
        self.update_agents(state=sim_state)

    def get_sim_state(self) -> Dict:
        """Get the current state of the simulation."""
        return self.simulation.describe_state()

    def update_agents(self, state: Dict) -> None:
        """
        Update agents' observations and rewards based on the current state.

        Agents are visited in ``self._reward_calculation_order``.

        :param state: Simulation state, as returned by ``get_sim_state``.
        :type state: Dict
        """
        for agent_name in self._reward_calculation_order:
            agent = self.agents[agent_name]
            if self.step_counter > 0:  # can't get reward before first action
                agent.update_reward(state=state)
                agent.save_reward_to_history()
            agent.update_observation(state=state)  # order of this doesn't matter so just use reward order
            agent.reward_function.total_reward += agent.reward_function.current_reward

    def apply_agent_actions(self) -> None:
        """Apply all actions to simulation as requests."""
        for agent in self.agents.values():
            obs = agent.observation_manager.current_observation
            action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
            if SIM_OUTPUT.save_agent_logs:
                agent.logger.debug(f"Chosen Action: {action_choice}")
            request = agent.format_request(action_choice, parameters)
            response = self.simulation.apply_request(request)
            # Hand the outcome back to the agent together with the action that produced it.
            agent.process_action_response(
                timestep=self.step_counter,
                action=action_choice,
                parameters=parameters,
                request=request,
                response=response,
            )

    def pre_timestep(self) -> None:
        """Apply any pre-timestep logic that helps make sure we have the correct observations."""
        self.simulation.pre_timestep(self.step_counter)

    def advance_timestep(self) -> None:
        """Advance timestep."""
        self.step_counter += 1
        _LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
        self.update_agent_loggers()
        self.simulation.apply_timestep(self.step_counter)

    def update_agent_loggers(self) -> None:
        """Updates Agent Loggers with new timestep."""
        for agent in self.agents.values():
            agent.logger.update_timestep(self.step_counter)

    def calculate_truncated(self) -> bool:
        """
        Calculate whether the episode is truncated.

        :return: True once the step counter has reached ``options.max_episode_length``.
        :rtype: bool
        """
        return self.step_counter >= self.options.max_episode_length

    def action_mask(self, agent_name: str) -> np.ndarray:
        """
        Return the action mask for the agent.

        The mask has one entry per item of the agent's action map. A falsy (0) entry means this action cannot be
        performed during this step.

        :param agent_name: Name of the agent (a key of ``self.agents``).
        :type agent_name: str
        :return: Action mask, as an ``int8`` array.
        :rtype: np.ndarray
        """
        agent = self.agents[agent_name]
        mask = [True] * len(agent.action_manager.action_map)
        # NOTE: action_map keys are used directly as list indices into the mask.
        for i, action in agent.action_manager.action_map.items():
            request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
            mask[i] = self.simulation._request_manager.check_valid(request, {})
        return np.asarray(mask, dtype=np.int8)

    def close(self) -> None:
        """Close the game, this will close the simulation."""
        # NOTE(review): this returns the ``NotImplemented`` sentinel despite the ``-> None`` annotation.
        # Behaviour is kept as-is; confirm whether ``None`` or a raised NotImplementedError is intended.
        return NotImplemented

    def setup_for_episode(self, episode: int) -> None:
        """Perform any final configuration of components to make them ready for the game to start."""
        self.simulation.setup_for_episode(episode=episode)

    @staticmethod
    def _operating_state_from_config(node_cfg: Dict) -> "NodeOperatingState":
        """Return the node's configured operating state, defaulting to ON when it is missing or empty."""
        configured = node_cfg.get("operating_state")
        if not configured:
            return NodeOperatingState.ON
        return NodeOperatingState[configured.upper()]

    @classmethod
    def from_config(cls, cfg: Dict) -> "PrimaiteGame":
        """Create a PrimaiteGame object from a config dictionary.

        The following top-level keys of the config dictionary are read:
            1. io_settings: options for logging data during training (optional).
            2. game: options for the game itself, parsed into ``PrimaiteGameOptions`` (required).
            3. simulation: defines the network topology and the initial state of the simulation (optional).
            4. agents: list of agent definitions (optional).

        The specification for each of the major areas is described in a separate documentation page.
        # TODO: create documentation page and add links to it here.

        :param cfg: The config dictionary.
        :type cfg: dict
        :return: A PrimaiteGame object.
        :rtype: PrimaiteGame
        :raises ValueError: If the config contains an invalid node, service, application or agent type.
        """
        game = cls()
        game.options = PrimaiteGameOptions(**cfg["game"])
        game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False

        # 1. create simulation
        sim = game.simulation
        net = sim.network

        simulation_config = cfg.get("simulation", {})
        network_config = simulation_config.get("network", {})
        airspace_cfg = network_config.get("airspace", {})
        frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})

        # Config keys are frequency names; convert them to AirSpaceFrequency members.
        frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()}
        net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg

        nodes_cfg = network_config.get("nodes", [])
        links_cfg = network_config.get("links", [])

        # Set the NMNE capture config
        NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {}))

        for node_cfg in nodes_cfg:
            n_type = node_cfg["type"]
            if n_type == "computer":
                new_node = Computer(
                    hostname=node_cfg["hostname"],
                    ip_address=node_cfg["ip_address"],
                    subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
                    default_gateway=node_cfg.get("default_gateway"),
                    dns_server=node_cfg.get("dns_server", None),
                    operating_state=cls._operating_state_from_config(node_cfg),
                )
            elif n_type == "server":
                new_node = Server(
                    hostname=node_cfg["hostname"],
                    ip_address=node_cfg["ip_address"],
                    subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")),
                    default_gateway=node_cfg.get("default_gateway"),
                    dns_server=node_cfg.get("dns_server", None),
                    operating_state=cls._operating_state_from_config(node_cfg),
                )
            elif n_type == "switch":
                new_node = Switch(
                    hostname=node_cfg["hostname"],
                    num_ports=int(node_cfg.get("num_ports", "8")),
                    operating_state=cls._operating_state_from_config(node_cfg),
                )
            elif n_type == "router":
                new_node = Router.from_config(node_cfg)
            elif n_type == "firewall":
                new_node = Firewall.from_config(node_cfg)
            elif n_type == "wireless_router":
                new_node = WirelessRouter.from_config(node_cfg, airspace=net.airspace)
            elif n_type == "printer":
                # NOTE(review): unlike computer/server, subnet_mask is mandatory here and passed through unconverted.
                new_node = Printer(
                    hostname=node_cfg["hostname"],
                    ip_address=node_cfg["ip_address"],
                    subnet_mask=node_cfg["subnet_mask"],
                    operating_state=cls._operating_state_from_config(node_cfg),
                )
            else:
                msg = f"invalid node type {n_type} in config"
                _LOGGER.error(msg)
                raise ValueError(msg)

            if "services" in node_cfg:
                for service_cfg in node_cfg["services"]:
                    new_service = None
                    service_type = service_cfg["type"]
                    if service_type in SERVICE_TYPES_MAPPING:
                        _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
                        new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type])
                        new_service = new_node.software_manager.software[service_type]

                        # fixing duration for the service
                        if "fix_duration" in service_cfg.get("options", {}):
                            new_service.fixing_duration = service_cfg["options"]["fix_duration"]

                        # start the service
                        new_service.start()
                    else:
                        msg = f"Configuration contains an invalid service type: {service_type}"
                        _LOGGER.error(msg)
                        raise ValueError(msg)

                    # service-dependent options
                    if service_type == "DNSClient":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            if "dns_server" in opt:
                                new_service.dns_server = IPv4Address(opt["dns_server"])
                    if service_type == "DNSServer":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            if "domain_mapping" in opt:
                                for domain, ip in opt["domain_mapping"].items():
                                    new_service.dns_register(domain, IPv4Address(ip))
                    if service_type == "DatabaseService":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            new_service.password = opt.get("db_password", None)
                            if "backup_server_ip" in opt:
                                new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
                    if service_type == "FTPServer":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            new_service.server_password = opt.get("server_password")
                    if service_type == "NTPClient":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            new_service.ntp_server = IPv4Address(opt.get("ntp_server_ip"))

            if "applications" in node_cfg:
                for application_cfg in node_cfg["applications"]:
                    new_application = None
                    application_type = application_cfg["type"]

                    if application_type in Application._application_registry:
                        new_node.software_manager.install(Application._application_registry[application_type])
                        new_application = new_node.software_manager.software[application_type]  # grab the instance

                        # fixing duration for the application
                        if "fix_duration" in application_cfg.get("options", {}):
                            new_application.fixing_duration = application_cfg["options"]["fix_duration"]
                    else:
                        msg = f"Configuration contains an invalid application type: {application_type}"
                        _LOGGER.error(msg)
                        raise ValueError(msg)

                    # run the application
                    new_application.run()

                    if application_type == "DataManipulationBot":
                        if "options" in application_cfg:
                            opt = application_cfg["options"]
                            new_application.configure(
                                server_ip_address=IPv4Address(opt.get("server_ip")),
                                server_password=opt.get("server_password"),
                                payload=opt.get("payload", "DELETE"),
                                port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
                                data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
                            )
                    elif application_type == "RansomwareScript":
                        if "options" in application_cfg:
                            opt = application_cfg["options"]
                            new_application.configure(
                                server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
                                server_password=opt.get("server_password"),
                                payload=opt.get("payload", "ENCRYPT"),
                            )
                    elif application_type == "DatabaseClient":
                        if "options" in application_cfg:
                            opt = application_cfg["options"]
                            new_application.configure(
                                server_ip_address=IPv4Address(opt.get("db_server_ip")),
                                server_password=opt.get("server_password"),
                            )
                    elif application_type == "WebBrowser":
                        if "options" in application_cfg:
                            opt = application_cfg["options"]
                            new_application.target_url = opt.get("target_url")
                    elif application_type == "DoSBot":
                        if "options" in application_cfg:
                            opt = application_cfg["options"]
                            new_application.configure(
                                target_ip_address=IPv4Address(opt.get("target_ip_address")),
                                target_port=Port(opt.get("target_port", Port.POSTGRES_SERVER.value)),
                                payload=opt.get("payload"),
                                repeat=bool(opt.get("repeat")),
                                port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
                                dos_intensity=float(opt.get("dos_intensity", "1.0")),
                                max_sessions=int(opt.get("max_sessions", "1000")),
                            )

            if "network_interfaces" in node_cfg:
                # Only the interface definitions are needed; the mapping keys are not used.
                for nic_cfg in node_cfg["network_interfaces"].values():
                    new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

            # temporarily set to 0 so all nodes are initially on
            new_node.start_up_duration = 0
            new_node.shut_down_duration = 0

            net.add_node(new_node)
            # run through the power on step if the node is to be turned on at the start
            if new_node.operating_state == NodeOperatingState.ON:
                new_node.power_on()

            # set start up and shut down duration
            new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
            new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3))

        # 2. create links between nodes
        for link_cfg in links_cfg:
            node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"])
            node_b = net.get_node_by_hostname(link_cfg["endpoint_b_hostname"])
            bandwidth = link_cfg.get("bandwidth", DEFAULT_BANDWIDTH)  # default value if not configured

            # Switches and all other node types are indexed the same way, so no type check is needed.
            endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
            endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
            net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth)

        # 3. create agents
        agents_cfg = cfg.get("agents", [])

        for agent_cfg in agents_cfg:
            agent_ref = agent_cfg["ref"]
            agent_type = agent_cfg["type"]
            action_space_cfg = agent_cfg["action_space"]
            observation_space_cfg = agent_cfg["observation_space"]
            reward_function_cfg = agent_cfg["reward_function"]

            # CREATE OBSERVATION SPACE
            obs_space = ObservationManager.from_config(observation_space_cfg)

            # CREATE ACTION SPACE
            action_space = ActionManager.from_config(game, action_space_cfg)

            # CREATE REWARD FUNCTION
            reward_function = RewardFunction.from_config(reward_function_cfg)

            # CREATE AGENT
            if agent_type == "ProbabilisticAgent":
                # TODO: implement non-random agents and fix this parsing
                settings = agent_cfg.get("agent_settings", {})
                new_agent = ProbabilisticAgent(
                    agent_name=agent_ref,
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=reward_function,
                    settings=settings,
                )
            elif agent_type == "PeriodicAgent":
                # NOTE(review): this reads the "settings" key, whereas the other agent types read "agent_settings".
                settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
                new_agent = PeriodicAgent(
                    agent_name=agent_ref,
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=reward_function,
                    settings=settings,
                )
            elif agent_type == "ProxyAgent":
                agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
                new_agent = ProxyAgent(
                    agent_name=agent_ref,
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=reward_function,
                    agent_settings=agent_settings,
                )
                # Proxy agents are additionally tracked as the RL-controlled subset.
                game.rl_agents[agent_ref] = new_agent
            elif agent_type == "RedDatabaseCorruptingAgent":
                agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
                new_agent = DataManipulationAgent(
                    agent_name=agent_ref,
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=reward_function,
                    agent_settings=agent_settings,
                )
            elif agent_type == "TAP001":
                agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
                new_agent = TAP001(
                    agent_name=agent_ref,
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=reward_function,
                    agent_settings=agent_settings,
                )
            else:
                msg = f"Configuration error: {agent_type} is not a valid agent type."
                _LOGGER.error(msg)
                raise ValueError(msg)
            game.agents[agent_ref] = new_agent

        # Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
        game.setup_reward_sharing()

        game.update_agents(game.get_sim_state())
        return game

    def setup_reward_sharing(self):
        """Do necessary setup to enable reward sharing between agents.

        This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1
        depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop.

        Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is
        sharing their reward. This callback is provided by this setup method.

        Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards
        that rely on the value of another reward are evaluated later.

        :raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
        """
        # construct dependency graph in the reward sharing between agents.
        graph = {}
        for name, agent in self.agents.items():
            graph[name] = set()
            for comp, _weight in agent.reward_function.reward_components:
                if isinstance(comp, SharedReward):
                    graph[name].add(comp.agent_name)

                    # while constructing the graph, we might as well set up the reward sharing itself.
                    comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward

        # make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
        if graph_has_cycle(graph):
            # Adjacent string literals are concatenated into one message (no commas, which would build a tuple).
            raise RuntimeError(
                "Detected cycle in agent reward sharing. Check the agent reward function "
                "configuration: reward sharing can only go one way."
            )

        # sort the agents so the rewards that depend on other rewards are always evaluated later
        self._reward_calculation_order = topological_sort(graph)