Files
PrimAITE/src/primaite/game/session.py

463 lines
21 KiB
Python
Raw Normal View History

2023-10-19 01:56:40 +01:00
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
2023-10-06 10:36:29 +01:00
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
2023-10-06 10:36:29 +01:00
import gymnasium
from gymnasium.core import ActType, ObsType
2023-10-06 10:36:29 +01:00
from pydantic import BaseModel
2023-10-09 18:35:30 +01:00
from primaite import getLogger
2023-10-06 10:36:29 +01:00
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
2023-10-06 20:32:52 +01:00
from primaite.game.agent.rewards import RewardFunction
from primaite.game.policy.policy import PolicyABC
2023-10-06 10:36:29 +01:00
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
2023-10-23 16:26:34 +01:00
from primaite.simulator.system.applications.application import Application
2023-10-06 10:36:29 +01:00
from primaite.simulator.system.applications.database_client import DatabaseClient
2023-10-23 16:26:34 +01:00
from primaite.simulator.system.applications.web_browser import WebBrowser
2023-10-09 18:35:30 +01:00
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
2023-10-06 10:36:29 +01:00
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
2023-10-23 16:26:34 +01:00
from primaite.simulator.system.services.web_server.web_server import WebServer
2023-10-06 10:36:29 +01:00
_LOGGER = getLogger(__name__)
2023-10-19 01:56:40 +01:00
2023-11-15 12:52:18 +00:00
class PrimaiteGymEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
    assumptions about the agent list always having a list of length 1.
    """

    def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
        """
        Initialise the environment.

        :param session: The session whose simulation and agents this environment drives.
        :param agents: The RL agents of the session. Only the first entry is used.
        """
        super().__init__()
        self.session: "PrimaiteSession" = session
        # Single-agent API: only the first agent is controlled through this env.
        # NOTE(review): an empty ``agents`` list raises IndexError here.
        self.agent: ProxyAgent = agents[0]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Perform a step in the environment.

        :param action: The action chosen by the RL policy for this timestep.
        :return: Tuple of (observation, reward, terminated, truncated, info). ``terminated`` and ``truncated`` are
            always ``False`` and ``info`` is always empty.
        """
        # make ProxyAgent store the action chosen by the RL policy
        self.agent.store_action(action)
        # apply_agent_actions accesses the action we just stored
        self.session.apply_agent_actions()
        self.session.advance_timestep()

        # Refresh every agent's observation and reward from the post-step simulation state.
        state = self.session.get_sim_state()
        self.session.update_agents(state)

        next_obs = self._get_obs()
        reward = self.agent.reward_function.current_reward
        terminated = False
        truncated = False
        info = {}
        return next_obs, reward, terminated, truncated, info

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """
        Reset the environment.

        :param seed: Optional seed, forwarded to ``gymnasium.Env.reset`` to seed the env's random number generator.
        :param options: Accepted for compatibility with the gymnasium ``reset`` signature; currently unused.
        :return: Tuple of (observation, info). ``info`` is always empty.
        """
        # gymnasium wrappers call ``reset(seed=..., options=...)``; the base class handles the seeding.
        super().reset(seed=seed)

        self.session.reset()
        state = self.session.get_sim_state()
        self.session.update_agents(state)

        next_obs = self._get_obs()
        info = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        return self.agent.action_manager.space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the observation space of the environment, flattened to match the output of ``_get_obs``."""
        return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)

    def _get_obs(self) -> ObsType:
        """Return the agent's current observation, flattened with respect to its observation space."""
        unflat_space = self.agent.observation_manager.space
        unflat_obs = self.agent.observation_manager.current_observation
        return gymnasium.spaces.flatten(unflat_space, unflat_obs)
2023-10-06 10:36:29 +01:00
class PrimaiteSessionOptions(BaseModel):
    """
    Game-wide settings that apply to every agent in the session.

    At present these only restrict which ports and protocols exist in the simulated world.
    """

    # Names of the ports that exist in the simulation.
    ports: List[str]
    # Names of the protocols that exist in the simulation.
    protocols: List[str]
2023-10-02 17:21:43 +01:00
2023-10-19 01:56:40 +01:00
2023-10-10 09:48:04 +01:00
class TrainingOptions(BaseModel):
    """Settings that control how the RL agent is trained and evaluated."""

    # RL library and algorithm selection.
    rl_framework: Literal["SB3", "RLLIB"]
    rl_algorithm: Literal["PPO", "A2C"]
    seed: Optional[int]

    # Learning phase.
    n_learn_episodes: int
    n_learn_steps: int

    # Evaluation phase; evaluation is skipped by the session when n_eval_episodes is 0.
    n_eval_episodes: int = 0
    n_eval_steps: Optional[int] = None
    deterministic_eval: bool

    # Agents taking part in training.
    n_agents: int
    agent_references: List[str]
2023-10-10 09:48:04 +01:00
2023-10-02 17:21:43 +01:00
class PrimaiteSession:
    """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""

    def __init__(self):
        """Initialise a PrimaiteSession object."""
        self.simulation: Simulation = Simulation()
        """Simulation object with which the agents will interact."""

        self.agents: List[AbstractAgent] = []
        """List of agents."""

        self.rl_agents: List[ProxyAgent] = []
        """Subset of agent list including only the reinforcement learning agents."""

        self.step_counter: int = 0
        """Current timestep within the episode."""

        self.episode_counter: int = 0
        """Current episode number."""

        self.options: PrimaiteSessionOptions
        """Special options that apply for the entire game."""

        self.training_options: TrainingOptions
        """Options specific to agent training."""

        self.policy: PolicyABC
        """The reinforcement learning policy."""

        self.ref_map_nodes: Dict[str, Node] = {}
        """
        Mapping from unique node reference name to node. Used when parsing config files.

        NOTE: ``from_config`` currently stores the node's ``uuid`` here rather than the node object.
        """

        self.ref_map_services: Dict[str, Service] = {}
        """Mapping from human-readable service reference to service object. Used for parsing config files."""

        self.ref_map_applications: Dict[str, Application] = {}
        """Mapping from human-readable application reference to application object. Used for parsing config files."""

        self.ref_map_links: Dict[str, Link] = {}
        """
        Mapping from human-readable link reference to link. Used when parsing config files.

        NOTE: ``from_config`` currently stores the link's ``uuid`` here rather than the link object.
        """

        self.env: PrimaiteGymEnv
        """The environment that the agent can consume. Could be PrimaiteEnv."""

    def start_session(self) -> None:
        """
        Commence the training session.

        Runs the learning phase if ``n_learn_episodes`` is positive, then the evaluation phase if ``n_eval_episodes``
        is positive, both through ``self.policy``.
        """
        opts = self.training_options

        if opts.n_learn_episodes > 0:
            self.policy.learn(n_episodes=opts.n_learn_episodes, n_time_steps=opts.n_learn_steps)

        if opts.n_eval_episodes > 0:
            # Honour the configured value instead of always evaluating deterministically.
            self.policy.eval(
                n_episodes=opts.n_eval_episodes,
                n_time_steps=opts.n_eval_steps,
                deterministic=opts.deterministic_eval,
            )

    def step(self):
        """
        Perform one step of the simulation/agent loop.

        This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
        agent. The steps are as follows:
            1. The simulation state is updated.
            2. The simulation state is sent to each agent.
            3. Each agent converts the state to an observation and calculates a reward.
            4. Each agent chooses an action based on the observation.
            5. Each agent converts the action to a request.
            6. The simulation applies the requests.

        Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
        interacts with should implement a step method that calls methods used by this method. For example, if using a
        single-agent gym, make sure to update the ProxyAgent's action with the action before calling
        ``self.apply_agent_actions()``.
        """
        _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")

        # Get the current state of the simulation
        sim_state = self.get_sim_state()

        # Update agents' observations and rewards based on the current state
        self.update_agents(sim_state)

        # Apply all actions to simulation as requests
        self.apply_agent_actions()

        # Advance timestep
        self.advance_timestep()

    def get_sim_state(self) -> Dict:
        """Get the current state of the simulation."""
        return self.simulation.describe_state()

    def update_agents(self, state: Dict) -> None:
        """
        Update agents' observations and rewards based on the current state.

        :param state: The simulation state, as returned by ``get_sim_state``.
        """
        for agent in self.agents:
            agent.update_observation(state)
            agent.update_reward(state)

    def apply_agent_actions(self) -> None:
        """Ask every agent for an action and apply each one to the simulation as a request."""
        for agent in self.agents:
            obs = agent.observation_manager.current_observation
            rew = agent.reward_function.current_reward
            action_choice, options = agent.get_action(obs, rew)
            request = agent.format_request(action_choice, options)
            self.simulation.apply_request(request)

    def advance_timestep(self) -> None:
        """Apply the current timestep to the simulation, then increment the step counter."""
        self.simulation.apply_timestep(self.step_counter)
        self.step_counter += 1

    def reset(self) -> None:
        """Reset the session, this will reset the simulation."""
        # NOTE(review): not implemented yet. The sentinel is returned (not raised) so that callers such as
        # PrimaiteGymEnv.reset keep working; no state (e.g. step_counter) is actually reset.
        return NotImplemented

    def close(self) -> None:
        """Close the session, this will stop the env and close the simulation."""
        # NOTE(review): not implemented yet; nothing is closed.
        return NotImplemented

    @classmethod
    def from_config(cls, cfg: dict) -> "PrimaiteSession":
        """Create a PrimaiteSession object from a config dictionary.

        The config dictionary should have the following top-level keys:
        1. training_config: options for training the RL agent.
        2. game_config: options for the game itself. Used by PrimaiteSession.
        3. simulation: defines the network topology and the initial state of the simulation.

        The specification for each of the three major areas is described in a separate documentation page.
        # TODO: create documentation page and add links to it here.

        Note that the agents' ``action_space`` sections of ``cfg`` are modified in place: node and router references
        are translated into UUIDs before being handed to the action manager.

        :param cfg: The config dictionary.
        :type cfg: dict
        :return: A PrimaiteSession object.
        :rtype: PrimaiteSession
        """
        sess = cls()
        sess.options = PrimaiteSessionOptions(
            ports=cfg["game_config"]["ports"],
            protocols=cfg["game_config"]["protocols"],
        )
        sess.training_options = TrainingOptions(**cfg["training_config"])

        net = sess.simulation.network
        nodes_cfg = cfg["simulation"]["network"]["nodes"]
        links_cfg = cfg["simulation"]["network"]["links"]

        # 1. create nodes and install their software
        for node_cfg in nodes_cfg:
            node_ref = node_cfg["ref"]
            new_node = cls._create_node(node_cfg)
            if new_node is None:
                # Skip rather than fall through with an undefined (or the previous) node.
                _LOGGER.warning(f"invalid node type '{node_cfg['type']}' for node '{node_ref}', skipping")
                continue

            cls._install_services(sess, new_node, node_cfg)
            cls._install_applications(sess, new_node, node_cfg)

            for nic_cfg in node_cfg.get("nics", {}).values():
                new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

            net.add_node(new_node)
            new_node.power_on()
            # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
            sess.ref_map_nodes[node_ref] = new_node.uuid

        # 2. create links between nodes
        for link_cfg in links_cfg:
            node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
            node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
            endpoint_a = cls._link_endpoint(node_a, link_cfg["endpoint_a_port"])
            endpoint_b = cls._link_endpoint(node_b, link_cfg["endpoint_b_port"])
            new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
            sess.ref_map_links[link_cfg["ref"]] = new_link.uuid

        # 3. create agents
        for agent_cfg in cfg["game_config"]["agents"]:
            cls._create_agent(sess, agent_cfg)

        # CREATE ENVIRONMENT
        sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)

        # CREATE POLICY
        sess.policy = PolicyABC.from_config(sess.training_options, session=sess)

        return sess

    @staticmethod
    def _create_node(node_cfg: dict) -> "Optional[Node]":
        """Build the node described by ``node_cfg``; return ``None`` if its ``type`` is not recognised."""
        n_type = node_cfg["type"]

        if n_type == "computer":
            return Computer(
                hostname=node_cfg["hostname"],
                ip_address=node_cfg["ip_address"],
                subnet_mask=node_cfg["subnet_mask"],
                default_gateway=node_cfg["default_gateway"],
                dns_server=node_cfg["dns_server"],
            )

        if n_type == "server":
            return Server(
                hostname=node_cfg["hostname"],
                ip_address=node_cfg["ip_address"],
                subnet_mask=node_cfg["subnet_mask"],
                default_gateway=node_cfg["default_gateway"],
                dns_server=node_cfg.get("dns_server"),
            )

        if n_type == "switch":
            return Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))

        if n_type == "router":
            router = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
            for port_num, port_cfg in node_cfg.get("ports", {}).items():
                router.configure_port(
                    port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
                )
            for r_num, r_cfg in node_cfg.get("acl", {}).items():
                # Missing or empty port/protocol entries mean "not specified" and are passed as None.
                # TODO Refactor Port/IPProtocol lookup.
                src_port = r_cfg.get("src_port")
                dst_port = r_cfg.get("dst_port")
                protocol = r_cfg.get("protocol")
                router.acl.add_rule(
                    action=ACLAction[r_cfg["action"]],
                    src_port=Port[src_port] if src_port else None,
                    dst_port=Port[dst_port] if dst_port else None,
                    protocol=IPProtocol[protocol] if protocol else None,
                    # NOTE(review): source and destination both read the same "ip_address" key, so they can
                    # never differ - confirm against the config schema.
                    src_ip_address=r_cfg.get("ip_address"),
                    dst_ip_address=r_cfg.get("ip_address"),
                    position=r_num,
                )
            return router

        return None

    @staticmethod
    def _install_services(sess: "PrimaiteSession", node: "Node", node_cfg: dict) -> None:
        """Install the services listed under ``node_cfg["services"]`` on ``node`` and register their references."""
        # key is equal to the 'name' attr of the service class itself.
        service_types_mapping = {
            "DNSClient": DNSClient,
            "DNSServer": DNSServer,
            "DatabaseClient": DatabaseClient,
            "DatabaseService": DatabaseService,
            "WebServer": WebServer,
            "DataManipulationBot": DataManipulationBot,
        }

        for service_cfg in node_cfg.get("services", []):
            service_ref = service_cfg["ref"]
            service_type = service_cfg["type"]

            if service_type not in service_types_mapping:
                _LOGGER.warning(f"service type not found {service_type}")
                continue

            _LOGGER.debug(f"installing {service_type} on node {node.hostname}")
            node.software_manager.install(service_types_mapping[service_type])
            new_service = node.software_manager.software[service_type]
            sess.ref_map_services[service_ref] = new_service

            # service-dependent options
            if "options" not in service_cfg:
                continue
            opt = service_cfg["options"]
            if service_type == "DatabaseClient" and "db_server_ip" in opt:
                new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
            if service_type == "DNSServer" and "domain_mapping" in opt:
                for domain, ip in opt["domain_mapping"].items():
                    new_service.dns_register(domain, ip)

    @staticmethod
    def _install_applications(sess: "PrimaiteSession", node: "Node", node_cfg: dict) -> None:
        """Install the applications listed under ``node_cfg["applications"]`` and register their references."""
        application_types_mapping = {
            "WebBrowser": WebBrowser,
        }

        for application_cfg in node_cfg.get("applications", []):
            application_ref = application_cfg["ref"]
            application_type = application_cfg["type"]

            if application_type not in application_types_mapping:
                _LOGGER.warning(f"application type not found {application_type}")
                continue

            node.software_manager.install(application_types_mapping[application_type])
            new_application = node.software_manager.software[application_type]
            sess.ref_map_applications[application_ref] = new_application

    @staticmethod
    def _link_endpoint(node: "Node", port_num: Any) -> Any:
        """Return the port of ``node`` identified by ``port_num``: a switch port for switches, else an ethernet port."""
        ports = node.switch_ports if isinstance(node, Switch) else node.ethernet_port
        return ports[port_num]

    @staticmethod
    def _create_agent(sess: "PrimaiteSession", agent_cfg: dict) -> None:
        """Build the agent described by ``agent_cfg`` and add it to the session's agent lists."""
        agent_ref = agent_cfg["ref"]
        agent_type = agent_cfg["type"]
        action_space_cfg = agent_cfg["action_space"]
        observation_space_cfg = agent_cfg["observation_space"]
        reward_function_cfg = agent_cfg["reward_function"]

        # CREATE OBSERVATION SPACE
        obs_space = ObservationManager.from_config(observation_space_cfg, sess)

        # CREATE ACTION SPACE
        # Each action space can potentially have a different list of nodes that it can apply to. Therefore,
        # we will pass node_uuids as a part of the action space config.
        # However, it's not possible to specify the node uuids directly in the config, as they are generated
        # dynamically, so we have to translate node references to uuids before passing this config on.
        # ``setdefault`` keeps this working when the config has no "options" section at all.
        action_options = action_space_cfg.setdefault("options", {})
        action_options["node_uuids"] = [
            sess.ref_map_nodes[node_option["node_ref"]]
            for node_option in action_options.pop("nodes", [])
            if "node_ref" in node_option
        ]

        for action_config in action_space_cfg.get("action_list", []):
            router_options = action_config.get("options")
            if router_options and "target_router_ref" in router_options:
                _target = router_options["target_router_ref"]
                router_options["target_router_uuid"] = sess.ref_map_nodes[_target]

        action_space = ActionManager.from_config(sess, action_space_cfg)

        # CREATE REWARD FUNCTION
        rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)

        # CREATE AGENT
        # TODO: implement non-random agents and fix this parsing
        if agent_type in ("GreenWebBrowsingAgent", "RedDatabaseCorruptingAgent"):
            agent_class = RandomAgent
        elif agent_type == "ProxyAgent":
            agent_class = ProxyAgent
        else:
            _LOGGER.warning(f"agent type not found: '{agent_type}' (agent '{agent_ref}')")
            return

        new_agent = agent_class(
            agent_name=agent_ref,
            action_space=action_space,
            observation_space=obs_space,
            reward_function=rew_function,
        )
        sess.agents.append(new_agent)
        if agent_type == "ProxyAgent":
            # Proxy agents are the ones driven by the RL policy through the gym environment.
            sess.rl_agents.append(new_agent)