Get the db admin green agent working

This commit is contained in:
Marek Wolan
2024-02-27 13:30:16 +00:00
parent c54f82fb1b
commit af8ca82fcb
16 changed files with 386 additions and 87 deletions

View File

@@ -33,7 +33,12 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenUC2Agent
type: probabilistic_agent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
@@ -69,15 +74,14 @@ agents:
reward_components:
- type: DUMMY
- ref: client_1_green_user
team: GREEN
type: probabilistic_agent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
- ref: client_1_green_user
team: GREEN
type: GreenUC2Agent
observation_space:
type: UC2GreenObservation
action_space:
@@ -113,11 +117,6 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1

View File

@@ -27,7 +27,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -607,7 +607,6 @@ class ActionManager:
def __init__(
self,
game: "PrimaiteGame", # reference to game for information lookup
actions: List[Dict], # stores list of actions available to agent
nodes: List[Dict], # extra configuration for each node
max_folders_per_node: int = 2, # allows calculating shape
@@ -618,7 +617,7 @@ class ActionManager:
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
"""Init method for ActionManager.
@@ -649,7 +648,6 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.game: "PrimaiteGame" = game
self.node_names: List[str] = [n["node_name"] for n in nodes]
"""List of node names in this action space. The list order is the mapping between node index and node name."""
self.application_names: List[List[str]] = []
@@ -707,25 +705,7 @@ class ActionManager:
self.protocols: List[str] = protocols
self.ports: List[str] = ports
self.ip_address_list: List[str]
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
# the nodes in the simulation.
# TODO: refactor. Options:
# 1: This should be pulled out into its own function for clarity
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
# go through the nodes here.
if ip_address_list is not None:
self.ip_address_list = ip_address_list
else:
self.ip_address_list = []
for node_name in self.node_names:
node_obj = self.game.simulation.network.get_node_by_hostname(node_name)
if node_obj is None:
continue
network_interfaces = node_obj.network_interfaces
for nic_uuid, nic_obj in network_interfaces.items():
self.ip_address_list.append(nic_obj.ip_address)
self.ip_address_list: List[str] = ip_address_list
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
@@ -958,6 +938,12 @@ class ActionManager:
:return: The constructed ActionManager.
:rtype: ActionManager
"""
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
# the nodes in the simulation.
# TODO: refactor. Options:
# 1: This should be pulled out into its own function for clarity
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
# go through the nodes here.
ip_address_order = cfg["options"].pop("ip_address_order", {})
ip_address_list = []
for entry in ip_address_order:
@@ -967,13 +953,22 @@ class ActionManager:
ip_address = node_obj.network_interface[nic_num].ip_address
ip_address_list.append(ip_address)
if not ip_address_list:
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
for node_name in node_names:
node_obj = game.simulation.network.get_node_by_hostname(node_name)
if node_obj is None:
continue
network_interfaces = node_obj.network_interfaces
for nic_uuid, nic_obj in network_interfaces.items():
ip_address_list.append(nic_obj.ip_address)
obj = cls(
game=game,
actions=cfg["action_list"],
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=ip_address_list or None,
ip_address_list=ip_address_list,
act_map=cfg.get("action_map"),
)

View File

@@ -1,5 +1,5 @@
import random
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
@@ -26,7 +26,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
@@ -36,12 +36,10 @@ class DataManipulationAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.game.step_counter
if current_timestep < self.next_execution_timestep:
if timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}

View File

@@ -112,7 +112,7 @@ class AbstractAgent(ABC):
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -122,6 +122,8 @@ class AbstractAgent(ABC):
:type obs: ObsType
:param reward: Reward from the previous action, defaults to 0.0 TODO: should this parameter even be accepted?
:type reward: float, optional
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
:type timestep: Optional[int]
:return: Action to be taken in the environment.
:rtype: Tuple[str, Dict]
"""
@@ -144,13 +146,13 @@ class AbstractAgent(ABC):
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""
...
pass
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
@@ -183,7 +185,7 @@ class ProxyAgent(AbstractAgent):
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""
Return the agent's most recent action, formatted in CAOS format.

View File

@@ -11,30 +11,39 @@ from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
class GreenUC2Agent(AbstractScriptedAgent):
"""Scripted agent which attempts to send web requests to a target node."""
class ProbabilisticAgent(AbstractScriptedAgent):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
class Settings(pydantic.BaseModel):
"""Config schema for Probabilistic agent settings."""
class GreenUC2AgentSettings(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="forbid")
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
"""Make sure the probabilities sum to 1."""
if not abs(sum(v.values()) - 1) < 1e-6:
raise ValueError(f"Green action probabilities must sum to 1")
raise ValueError("Green action probabilities must sum to 1")
return v
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]:
"""Ensure that the keys of the probability dictionary cover all integers from 0 to N."""
if not all((i in v) for i in range(len(v))):
raise ValueError(
"Green action probabilities must be defined as a mapping where the keys are consecutive integers "
"from 0 to N."
)
return v
def __init__(
self,
@@ -52,23 +61,27 @@ class GreenUC2Agent(AbstractScriptedAgent):
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
self.settings = GreenUC2Agent.GreenUC2AgentSettings(settings)
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
# convert probabilities from a dict of {action_index: probability} into an ordered array
self.probabilities = np.array[self.settings.action_probabilities.values()]
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
def get_action(self, obs: ObsType, reward: float = 0) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""
Choose a random action from the action space.
The probability of each action is given by the corresponding index in ``self.probabilities``.
:param obs: Current observation of the simulation
:type obs: ObsType
:param reward: Reward for the last step, not used for scripted agents, defaults to 0
:type reward: float, optional
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
return self.action_manager.get_action(choice)
raise NotImplementedError
class RedDatabaseCorruptingAgent(AbstractScriptedAgent):
"""Scripted agent which attempts to corrupt the database of the target node."""
raise NotImplementedError

View File

@@ -7,10 +7,10 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import GreenUC2Agent
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -165,7 +165,7 @@ class PrimaiteGame:
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
@@ -393,14 +393,15 @@ class PrimaiteGame:
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
# CREATE AGENT
if agent_type == "GreenUC2Agent":
if agent_type == "probabilistic_agent":
# TODO: implement non-random agents and fix this parsing
new_agent = GreenUC2Agent(
settings = agent_cfg.get("agent_settings")
new_agent = ProbabilisticAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
settings=settings,
)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":

View File

@@ -334,7 +334,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"tags": []
},
@@ -346,7 +346,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"tags": []
},
@@ -371,11 +371,150 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-02-27 09:43:39,312::WARNING::primaite.game.game::275::service type not found DatabaseClient\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Resetting environment, episode 0, avg. reward: 0.0\n",
"env created successfully\n",
"{'ACL': {1: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 0,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 2: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 1,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 3: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 2,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 4: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 3,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 5: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 4,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 6: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 5,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 7: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 6,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 8: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 7,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 9: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 8,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 10: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 9,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0}},\n",
" 'ICS': 0,\n",
" 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n",
" 2: {'PROTOCOLS': {'ALL': 1}},\n",
" 3: {'PROTOCOLS': {'ALL': 1}},\n",
" 4: {'PROTOCOLS': {'ALL': 1}},\n",
" 5: {'PROTOCOLS': {'ALL': 1}},\n",
" 6: {'PROTOCOLS': {'ALL': 1}},\n",
" 7: {'PROTOCOLS': {'ALL': 1}},\n",
" 8: {'PROTOCOLS': {'ALL': 1}},\n",
" 9: {'PROTOCOLS': {'ALL': 1}},\n",
" 10: {'PROTOCOLS': {'ALL': 0}}},\n",
" 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n",
" 'operating_status': 1},\n",
" 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n",
" 'operating_status': 1},\n",
" 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n",
" 'health_status': 1}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n",
" 2: {'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1}}}\n"
]
}
],
"source": [
"# create the env\n",
"with open(example_config_path(), 'r') as f:\n",
@@ -403,7 +542,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -421,15 +560,57 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step: 211, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 212, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 213, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 214, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 215, Red action: DO NOTHING, Blue reward:-0.42\n",
"step: 216, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 217, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 218, Red action: DO NOTHING, Blue reward:-0.42\n",
"step: 219, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 220, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 221, Red action: ATTACK from client 2, Blue reward:-0.32\n",
"step: 222, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 223, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 224, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 225, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 226, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 227, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 228, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 229, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 230, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 231, Red action: DO NOTHING, Blue reward:-0.42\n",
"step: 232, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 233, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 234, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 235, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 236, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 237, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 238, Red action: ATTACK from client 2, Blue reward:-0.32\n",
"step: 239, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 240, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 241, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 242, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 243, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 244, Red action: DO NOTHING, Blue reward:-0.32\n",
"step: 245, Red action: DO NOTHING, Blue reward:-0.32\n"
]
}
],
"source": [
"for step in range(35):\n",
" obs, reward, terminated, truncated, info = env.step(0)\n",
" print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )"
]
},
{
@@ -509,7 +690,7 @@
"\n",
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
"\n",
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
]
},
{
@@ -523,8 +704,8 @@
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
"print(f\"Blue reward:{reward}\" )"
]
},
@@ -582,6 +763,33 @@
"obs['ACL']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = env.game.simulation.network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dbc = net.get_node_by_hostname('client_1').software_manager.software.get('DatabaseClient')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dbc._query_success_tracker"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -606,7 +814,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -21,7 +21,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -33,7 +33,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -25,7 +25,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -31,7 +31,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -29,7 +29,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -25,7 +25,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import pytest
import yaml
@@ -309,7 +309,7 @@ class ControlledAgent(AbstractAgent):
)
self.most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
return self.most_recent_action
@@ -478,7 +478,6 @@ def game_and_agent():
]
action_space = ActionManager(
game=game,
actions=actions, # ALL POSSIBLE ACTIONS
nodes=[
{

View File

@@ -0,0 +1,84 @@
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
def test_probabilistic_agent():
"""
Check that the probabilistic agent selects actions with approximately the right probabilities.
Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator),
we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values
were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number
generation.
"""
N_TRIALS = 10_000
P_DO_NOTHING = 0.1
P_NODE_APPLICATION_EXECUTE = 0.3
P_NODE_FILE_DELETE = 0.6
MIN_DO_NOTHING = 850
MAX_DO_NOTHING = 1150
MIN_NODE_APPLICATION_EXECUTE = 2800
MAX_NODE_APPLICATION_EXECUTE = 3200
MIN_NODE_FILE_DELETE = 5750
MAX_NODE_FILE_DELETE = 6250
action_space = ActionManager(
actions=[
{"type": "DONOTHING"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_FILE_DELETE"},
],
nodes=[
{
"node_name": "client_1",
"applications": [{"application_name": "WebBrowser"}],
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
},
],
max_folders_per_node=2,
max_files_per_folder=2,
max_services_per_node=2,
max_applications_per_node=2,
max_nics_per_node=2,
max_acl_rules=10,
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
act_map={
0: {"action": "DONOTHING", "options": {}},
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
},
)
observation_space = ObservationManager(ICSObservation())
reward_function = RewardFunction()
pa = ProbabilisticAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
settings={
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
"random_seed": 120,
},
)
do_nothing_count = 0
node_application_execute_count = 0
node_file_delete_count = 0
for _ in range(N_TRIALS):
a = pa.get_action(0, timestep=0)
if a == ("DONOTHING", {}):
do_nothing_count += 1
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
node_application_execute_count += 1
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
node_file_delete_count += 1
else:
raise AssertionError("Probabilistic agent produced an unexpected action.")
assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING
assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE
assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE