Files
PrimAITE/src/primaite/game/session.py
2023-11-17 10:20:26 +00:00

533 lines
23 KiB
Python

"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
import enlighten
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
progress_bar_manager = enlighten.get_manager()
_LOGGER = getLogger(__name__)
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.session.apply_agent_actions()
self.session.advance_timestep()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.session.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
"""
model_config = ConfigDict(extra="forbid")
ports: List[str]
protocols: List[str]
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
def __init__(self):
"""Initialise a PrimaiteSession object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.training_progress_bar: Optional[enlighten.Counter] = None
"""training steps counter"""
self.eval_progress_bar: Optional[enlighten.Counter] = None
"""evaluation episodes counter"""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
self.training_progress_bar = progress_bar_manager.counter(
total=n_learn_episodes * max_steps_per_episode, desc="Training steps"
)
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.eval_progress_bar = progress_bar_manager.counter(total=n_eval_episodes, desc="Evaluation episodes")
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
def step(self):
"""
Perform one step of the simulation/agent loop.
This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
agent. The steps are as follows:
1. The simulation state is updated.
2. The simulation state is sent to each agent.
3. Each agent converts the state to an observation and calculates a reward.
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
interacts with should implement a step method that calls methods used by this method. For example, if using a
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
agent.update_observation(state)
agent.update_reward(state)
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
def advance_timestep(self) -> None:
"""Advance timestep."""
self.step_counter += 1
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
self.simulation.apply_timestep(self.step_counter)
if self.training_progress_bar and self.mode == SessionMode.TRAIN:
self.training_progress_bar.update()
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
if current_step >= max_steps:
return True
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
if self.eval_progress_bar and self.mode == SessionMode.EVAL:
self.eval_progress_bar.update()
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
# TODO: create documentation page and add links to it here.
:param cfg: The config dictionary.
:type cfg: dict
:return: A PrimaiteSession object.
:rtype: PrimaiteSession
"""
sess = cls()
sess.options = PrimaiteSessionOptions(
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
net = sim.network
sess.ref_map_nodes: Dict[str, Node] = {}
sess.ref_map_services: Dict[str, Service] = {}
sess.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
for node_cfg in nodes_cfg:
node_ref = node_cfg["ref"]
n_type = node_cfg["type"]
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg["dns_server"],
)
elif n_type == "server":
new_node = Server(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg.get("dns_server"),
)
elif n_type == "switch":
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
elif n_type == "router":
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
# this: 'r_cfg.get('src_port')'
# Port/IPProtocol. TODO Refactor
new_node.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("ip_address"),
dst_ip_address=r_cfg.get("ip_address"),
position=r_num,
)
else:
print("invalid node type")
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
service_ref = service_cfg["ref"]
service_type = service_cfg["type"]
service_types_mapping = {
"DNSClient": DNSClient, # key is equal to the 'name' attr of the service class itself.
"DNSServer": DNSServer,
"DatabaseClient": DatabaseClient,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"DataManipulationBot": DataManipulationBot,
}
if service_type in service_types_mapping:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
sess.ref_map_services[service_ref] = new_service
else:
print(f"service type not found {service_type}")
# service-dependent options
if service_type == "DatabaseClient":
if "options" in service_cfg:
opt = service_cfg["options"]
if "db_server_ip" in opt:
new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
if service_type == "DNSServer":
if "options" in service_cfg:
opt = service_cfg["options"]
if "domain_mapping" in opt:
for domain, ip in opt["domain_mapping"].items():
new_service.dns_register(domain, ip)
if "applications" in node_cfg:
for application_cfg in node_cfg["applications"]:
application_ref = application_cfg["ref"]
application_type = application_cfg["type"]
application_types_mapping = {
"WebBrowser": WebBrowser,
}
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
sess.ref_map_applications[application_ref] = new_application
else:
print(f"application type not found {application_type}")
if "nics" in node_cfg:
for nic_num, nic_cfg in node_cfg["nics"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[
node_ref
] = (
new_node.uuid
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
else:
endpoint_a = node_a.ethernet_port[link_cfg["endpoint_a_port"]]
if isinstance(node_b, Switch):
endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
else:
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
# 3. create agents
game_cfg = cfg["game_config"]
agents_cfg = game_cfg["agents"]
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"] # noqa: F841
agent_type = agent_cfg["type"]
action_space_cfg = agent_cfg["action_space"]
observation_space_cfg = agent_cfg["observation_space"]
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
# However, it's not possible to specify the node uuids directly in the config, as they are generated
# dynamically, so we have to translate node references to uuids before passing this config on.
if "action_list" in action_space_cfg:
for action_config in action_space_cfg["action_list"]:
if "options" in action_config:
if "target_router_ref" in action_config["options"]:
_target = action_config["options"]["target_router_ref"]
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
# TODO: implement non-random agents and fix this parsing
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess