diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index a0e9667e..6813161d 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -33,7 +33,12 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenUC2Agent + type: probabilistic_agent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -69,15 +74,14 @@ agents: reward_components: - type: DUMMY + - ref: client_1_green_user + team: GREEN + type: probabilistic_agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - - - ref: client_1_green_user - team: GREEN - type: GreenUC2Agent observation_space: type: UC2GreenObservation action_space: @@ -113,11 +117,6 @@ agents: reward_components: - type: DUMMY - agent_settings: - action_probabilities: - 0: 0.3 - 1: 0.6 - 2: 0.1 diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 93019c9d..df6130d1 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -27,7 +27,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 1793d420..18cb6262 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -607,7 +607,6 @@ class ActionManager: def __init__( self, - game: "PrimaiteGame", # reference to game for information lookup actions: List[Dict], # stores list of actions available to agent nodes: List[Dict], # extra configuration for each node max_folders_per_node: int = 2, # allows calculating shape @@ -618,7 +617,7 @@ class ActionManager: max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port - ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. + ip_address_list: List[str] = [], # to allow us to map an index to an ip address. act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions ) -> None: """Init method for ActionManager. @@ -649,7 +648,6 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.game: "PrimaiteGame" = game self.node_names: List[str] = [n["node_name"] for n in nodes] """List of node names in this action space. The list order is the mapping between node index and node name.""" self.application_names: List[List[str]] = [] @@ -707,25 +705,7 @@ class ActionManager: self.protocols: List[str] = protocols self.ports: List[str] = ports - self.ip_address_list: List[str] - - # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from - # the nodes in the simulation. - # TODO: refactor. Options: - # 1: This should be pulled out into it's own function for clarity - # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to - # go through the nodes here. - if ip_address_list is not None: - self.ip_address_list = ip_address_list - else: - self.ip_address_list = [] - for node_name in self.node_names: - node_obj = self.game.simulation.network.get_node_by_hostname(node_name) - if node_obj is None: - continue - network_interfaces = node_obj.network_interfaces - for nic_uuid, nic_obj in network_interfaces.items(): - self.ip_address_list.append(nic_obj.ip_address) + self.ip_address_list: List[str] = ip_address_list # action_args are settings which are applied to the action space as a whole. global_action_args = { @@ -958,6 +938,12 @@ class ActionManager: :return: The constructed ActionManager. :rtype: ActionManager """ + # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from + # the nodes in the simulation. + # TODO: refactor. Options: + # 1: This should be pulled out into it's own function for clarity + # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to + # go through the nodes here. ip_address_order = cfg["options"].pop("ip_address_order", {}) ip_address_list = [] for entry in ip_address_order: @@ -967,13 +953,22 @@ class ActionManager: ip_address = node_obj.network_interface[nic_num].ip_address ip_address_list.append(ip_address) + if not ip_address_list: + node_names = [n["node_name"] for n in cfg.get("nodes", {})] + for node_name in node_names: + node_obj = game.simulation.network.get_node_by_hostname(node_name) + if node_obj is None: + continue + network_interfaces = node_obj.network_interfaces + for nic_uuid, nic_obj in network_interfaces.items(): + ip_address_list.append(nic_obj.ip_address) + obj = cls( - game=game, actions=cfg["action_list"], **cfg["options"], protocols=game.options.protocols, ports=game.options.ports, - ip_address_list=ip_address_list or None, + ip_address_list=ip_address_list, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 126c55ec..b5de9a5a 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -1,5 +1,5 @@ import random -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from gymnasium.core import ObsType @@ -26,7 +26,7 @@ class DataManipulationAgent(AbstractScriptedAgent): ) self.next_execution_timestep = timestep + random_timestep_increment - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -36,12 +36,10 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - current_timestep = self.action_manager.game.step_counter - - if current_timestep < self.next_execution_timestep: + if timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} - self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) + self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 276715f7..4f434bad 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -112,7 +112,7 @@ class AbstractAgent(ABC): return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -122,6 +122,8 @@ class AbstractAgent(ABC): :type obs: ObsType :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted? :type reward: float, optional + :param timestep: The current timestep in the simulation, used for non-RL agents. Optional + :type timestep: int :return: Action to be taken in the environment. :rtype: Tuple[str, Dict] """ @@ -144,13 +146,13 @@ class AbstractAgent(ABC): class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" - ... + pass class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -183,7 +185,7 @@ class ProxyAgent(AbstractAgent): self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py index a88e563d..28d94062 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents.py @@ -11,30 +11,39 @@ from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -class GreenUC2Agent(AbstractScriptedAgent): - """Scripted agent which attempts to send web requests to a target node.""" +class ProbabilisticAgent(AbstractScriptedAgent): + """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" + + class Settings(pydantic.BaseModel): + """Config schema for Probabilistic agent settings.""" - class GreenUC2AgentSettings(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra="forbid") + """Strict validation.""" action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" random_seed: Optional[int] = None + """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" + # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way + # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. @pydantic.field_validator("action_probabilities", mode="after") @classmethod def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: + """Make sure the probabilities sum to 1.""" if not abs(sum(v.values()) - 1) < 1e-6: - raise ValueError(f"Green action probabilities must sum to 1") + raise ValueError("Green action probabilities must sum to 1") return v @pydantic.field_validator("action_probabilities", mode="after") @classmethod def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]: + """Ensure that the keys of the probability dictionary cover all integers from 0 to N.""" if not all((i in v) for i in range(len(v))): raise ValueError( "Green action probabilities must be defined as a mapping where the keys are consecutive integers " "from 0 to N." ) + return v def __init__( self, @@ -52,23 +61,27 @@ class GreenUC2Agent(AbstractScriptedAgent): # If seed not specified, set it to None so that numpy chooses a random one. settings.setdefault("random_seed") - self.settings = GreenUC2Agent.GreenUC2AgentSettings(settings) + self.settings = ProbabilisticAgent.Settings(**settings) self.rng = np.random.default_rng(self.settings.random_seed) # convert probabilities from - self.probabilities = np.array[self.settings.action_probabilities.values()] + self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) - def get_action(self, obs: ObsType, reward: float = 0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + """ + Choose a random action from the action space. + + The probability of each action is given by the corresponding index in ``self.probabilities``. + + :param obs: Current observation of the simulation + :type obs: ObsType + :param reward: Reward for the last step, not used for scripted agents, defaults to 0 + :type reward: float, optional + :return: Action to be taken in CAOS format. + :rtype: Tuple[str, Dict] + """ choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) return self.action_manager.get_action(choice) - - raise NotImplementedError - - -class RedDatabaseCorruptingAgent(AbstractScriptedAgent): - """Scripted agent which attempts to corrupt the database of the target node.""" - - raise NotImplementedError diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index a9d564ba..b44abe16 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -7,10 +7,10 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent +from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.game.agent.scripted_agents import GreenUC2Agent +from primaite.game.agent.scripted_agents import ProbabilisticAgent from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -165,7 +165,7 @@ class PrimaiteGame: for agent in self.agents: obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward - action_choice, options = agent.get_action(obs, rew) + action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter) agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) @@ -393,14 +393,15 @@ class PrimaiteGame: agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) # CREATE AGENT - if agent_type == "GreenUC2Agent": + if agent_type == "probabilistic_agent": # TODO: implement non-random agents and fix this parsing - new_agent = GreenUC2Agent( + settings = agent_cfg.get("agent_settings") + new_agent = ProbabilisticAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=reward_function, - agent_settings=agent_settings, + settings=settings, ) game.agents.append(new_agent) elif agent_type == "ProxyAgent": diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index c4fe4c9a..fa4a28a4 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -334,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [] }, @@ -346,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "tags": [] }, @@ -371,11 +371,150 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-27 09:43:39,312::WARNING::primaite.game.game::275::service type not found DatabaseClient\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting environment, episode 0, avg. reward: 0.0\n", + "env created successfully\n", + "{'ACL': {1: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 0,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 2: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 1,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 3: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 2,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 4: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 3,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 5: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 4,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 6: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 5,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 7: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 6,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 8: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 7,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 9: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 8,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 10: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 9,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0}},\n", + " 'ICS': 0,\n", + " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", + " 2: {'PROTOCOLS': {'ALL': 1}},\n", + " 3: {'PROTOCOLS': {'ALL': 1}},\n", + " 4: {'PROTOCOLS': {'ALL': 1}},\n", + " 5: {'PROTOCOLS': {'ALL': 1}},\n", + " 6: {'PROTOCOLS': {'ALL': 1}},\n", + " 7: {'PROTOCOLS': {'ALL': 1}},\n", + " 8: {'PROTOCOLS': {'ALL': 1}},\n", + " 9: {'PROTOCOLS': {'ALL': 1}},\n", + " 10: {'PROTOCOLS': {'ALL': 0}}},\n", + " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", + " 'health_status': 1}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}}\n" + ] + } + ], "source": [ "# create the env\n", "with open(example_config_path(), 'r') as f:\n", @@ -403,7 +542,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -421,15 +560,57 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 211, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 212, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 213, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 214, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 215, Red action: DO NOTHING, Blue reward:-0.42\n", + "step: 216, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 217, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 218, Red action: DO NOTHING, Blue reward:-0.42\n", + "step: 219, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 220, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 221, Red action: ATTACK from client 2, Blue reward:-0.32\n", + "step: 222, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 223, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 224, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 225, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 226, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 227, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 228, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 229, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 230, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 231, Red action: DO NOTHING, Blue reward:-0.42\n", + "step: 232, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 233, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 234, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 235, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 236, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 237, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 238, Red action: ATTACK from client 2, Blue reward:-0.32\n", + "step: 239, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 240, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 241, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 242, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 243, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 244, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 245, Red action: DO NOTHING, Blue reward:-0.32\n" + ] + } + ], "source": [ "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", - " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" ] }, { @@ -509,7 +690,7 @@ "\n", "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n", "\n", - "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -523,8 +704,8 @@ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", "print(f\"Blue reward:{reward}\" )" ] }, @@ -582,6 +763,33 @@ "obs['ACL']" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = env.game.simulation.network" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dbc = net.get_node_by_hostname('client_1').software_manager.software.get('DatabaseClient')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dbc._query_success_tracker" + ] + }, { "cell_type": "code", "execution_count": null, @@ -606,7 +814,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 5bdc3273..892e6af7 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -21,7 +21,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index d1cec079..ad2ea787 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -33,7 +33,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 8361e318..9b668686 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 87bd9d1c..5a7d8366 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -31,7 +31,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 76190a64..42dd27fb 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -29,7 +29,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5d004c7e..8a4a1178 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/conftest.py b/tests/conftest.py index 5084c339..2add835f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import pytest import yaml @@ -309,7 +309,7 @@ class ControlledAgent(AbstractAgent): ) self.most_recent_action: Tuple[str, Dict] - def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" return self.most_recent_action @@ -478,7 +478,6 @@ def game_and_agent(): ] action_space = ActionManager( - game=game, actions=actions, # ALL POSSIBLE ACTIONS nodes=[ { diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py new file mode 100644 index 00000000..f0b37cac --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -0,0 +1,84 @@ +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents import ProbabilisticAgent + + +def test_probabilistic_agent(): + """ + Check that the probabilistic agent selects actions with approximately the right probabilities. + + Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator), + we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values + were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number + generation. + """ + N_TRIALS = 10_000 + P_DO_NOTHING = 0.1 + P_NODE_APPLICATION_EXECUTE = 0.3 + P_NODE_FILE_DELETE = 0.6 + MIN_DO_NOTHING = 850 + MAX_DO_NOTHING = 1150 + MIN_NODE_APPLICATION_EXECUTE = 2800 + MAX_NODE_APPLICATION_EXECUTE = 3200 + MIN_NODE_FILE_DELETE = 5750 + MAX_NODE_FILE_DELETE = 6250 + + action_space = ActionManager( + actions=[ + {"type": "DONOTHING"}, + {"type": "NODE_APPLICATION_EXECUTE"}, + {"type": "NODE_FILE_DELETE"}, + ], + nodes=[ + { + "node_name": "client_1", + "applications": [{"application_name": "WebBrowser"}], + "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], + }, + ], + max_folders_per_node=2, + max_files_per_folder=2, + max_services_per_node=2, + max_applications_per_node=2, + max_nics_per_node=2, + max_acl_rules=10, + protocols=["TCP", "UDP", "ICMP"], + ports=["HTTP", "DNS", "ARP"], + act_map={ + 0: {"action": "DONOTHING", "options": {}}, + 1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}}, + 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, + }, + ) + observation_space = ObservationManager(ICSObservation()) + reward_function = RewardFunction() + + pa = ProbabilisticAgent( + agent_name="test_agent", + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + settings={ + "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + "random_seed": 120, + }, + ) + + do_nothing_count = 0 + node_application_execute_count = 0 + node_file_delete_count = 0 + for _ in range(N_TRIALS): + a = pa.get_action(0, timestep=0) + if a == ("DONOTHING", {}): + do_nothing_count += 1 + elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): + node_application_execute_count += 1 + elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}): + node_file_delete_count += 1 + else: + raise AssertionError("Probabilistic agent produced an unexpected action.") + + assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING + assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE + assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE