Merged PR 285: Green agent that sometimes performs database connections

## Summary
* removed `GreenWebBrowsingAgent` because it was replaced by `ProbabilisticAgent`
* created new `ProbabilisticAgent`, which selects actions randomly from its action map with configurable probabilities
* slightly refactored action manager to decouple it from `PrimaiteGame` (as a consequence, agents should be given the current timestep if their `get_action()` method is time-dependent)
* refactored `data_manipulation_bot` to use an existing db client on the host rather than inheriting from it
* added new type of SQL query to databases: `"SELECT * FROM pg_stat_activity"` to model checking connection status
* added new execution definition on the `DatabaseClient` app which just performs that new SQL query
* added reward for the green admin being able to connect to the db
* updated the UC2 notebook to reflect the new changes
* updated documentation for data manipulation bot
* added new test for probabilistic agent
* added test for new reward
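
For reviewers, the sampling behaviour described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: it uses the standard library's `random.Random.choices` instead of NumPy, and the helper names `validate_probabilities` and `sample_action` are hypothetical. The probabilities are the ones used in the example config.

```python
import random
from typing import Dict


def validate_probabilities(probs: Dict[int, float]) -> None:
    """Check the two constraints described in the PR (hypothetical helper)."""
    # Probabilities must sum to 1, within a small tolerance.
    if abs(sum(probs.values()) - 1) >= 1e-6:
        raise ValueError("Action probabilities must sum to 1")
    # Keys must be the consecutive integers 0..len(probs)-1.
    if not all(i in probs for i in range(len(probs))):
        raise ValueError("Action probability keys must be consecutive integers starting at 0")


def sample_action(rng: random.Random, probs: Dict[int, float]) -> int:
    """Pick one action index, weighted by its configured probability."""
    actions = sorted(probs)
    weights = [probs[a] for a in actions]
    return rng.choices(actions, weights=weights, k=1)[0]


# Values from the example config: do nothing, web request, db status check.
action_probabilities = {0: 0.3, 1: 0.6, 2: 0.1}
validate_probabilities(action_probabilities)

rng = random.Random(123)  # a fixed seed makes the sequence repeatable
choices = [sample_action(rng, action_probabilities) for _ in range(1000)]
```

With a fixed seed the same sequence of action indices is produced on every run, which matches the intent of the `random_seed` setting.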

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [x] written **tests** for any new functionality added with this PR
- [x] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [x] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Related work items: #2319
Marek Wolan
2024-03-04 12:06:55 +00:00
34 changed files with 649 additions and 203 deletions

View File

@@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
...
- ref: green_agent_example
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:
@@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive
``type``
--------
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour.
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour.
Available agent types:
- ``GreenWebBrowsingAgent``
- ``ProbabilisticAgent``
- ``ProxyAgent``
- ``RedDatabaseCorruptingAgent``

View File

@@ -45,7 +45,7 @@ In a simulation, the bot can be controlled by using ``DataManipulationAgent`` wh
Implementation
==============
The bot extends :ref:`DatabaseClient` and leverages its connectivity.
The bot connects to a :ref:`DatabaseClient` and leverages its connectivity. The host running ``DataManipulationBot`` must also have a :ref:`DatabaseClient` installed on it.
- Uses the Application base class for lifecycle management.
- Credentials, target IP and other options set via ``configure``.
@@ -65,6 +65,7 @@ Python
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.database_client import DatabaseClient
client_1 = Computer(
hostname="client_1",
@@ -74,6 +75,7 @@ Python
operating_state=NodeOperatingState.ON # initialise the computer in an ON state
)
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
@@ -148,6 +150,10 @@ If not using the data manipulation bot manually, it needs to be used with a data
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
Configuration
=============

View File

@@ -25,6 +25,13 @@ Usage
- Clients connect, execute queries, and disconnect.
- Service runs on TCP port 5432 by default.
**Supported queries:**
* ``SELECT``: As long as the database file is in a ``GOOD`` health state, the db service will respond with a 200 status code.
* ``DELETE``: This query represents an attack. It causes the database file to enter a ``COMPROMISED`` state and returns a 200 status code.
* ``INSERT``: If the database service is in a healthy state, this returns a 200 status code; otherwise it returns a 404 status code.
* ``SELECT * FROM pg_stat_activity``: This query represents something an admin would send to check the status of the database. If the database service is in a healthy state, it returns a 200 status code, otherwise a 401 status code.
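
As an illustrative sketch only, the status codes listed above can be summarised as a lookup function. The function name ``expected_status_code`` and its boolean parameters are hypothetical and are not part of the PrimAITE API. Cases that the list does not specify return ``None``, and the side effect of ``DELETE`` on the database file is not modelled.

```python
from typing import Optional


def expected_status_code(query: str, file_good: bool, service_healthy: bool) -> Optional[int]:
    """Return the status code described in the list above, or None if unspecified."""
    if query == "SELECT * FROM pg_stat_activity":
        # Admin-style connection check: depends on the service health.
        return 200 if service_healthy else 401
    if query == "SELECT":
        # Only the GOOD-file case is documented; anything else is left unspecified.
        return 200 if file_good else None
    if query == "DELETE":
        # Represents an attack; documented as returning 200.
        return 200
    if query == "INSERT":
        return 200 if service_healthy else 404
    return None


print(expected_status_code("SELECT * FROM pg_stat_activity", file_good=True, service_healthy=False))
```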
Implementation
==============

View File

@@ -33,7 +33,12 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
@@ -45,24 +50,38 @@ agents:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
@@ -74,10 +93,26 @@ agents:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: DUMMY
@@ -85,6 +120,7 @@ agents:
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
@@ -572,6 +608,14 @@ agents:
weight: 0.33
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.1
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.1
options:
node_hostname: client_2
agent_settings:
@@ -722,6 +766,10 @@ simulation:
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: client_1_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
@@ -745,6 +793,10 @@ simulation:
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_2_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient

View File

@@ -35,7 +35,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:
@@ -64,7 +64,7 @@ agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -607,7 +607,6 @@ class ActionManager:
def __init__(
self,
game: "PrimaiteGame", # reference to game for information lookup
actions: List[Dict], # stores list of actions available to agent
nodes: List[Dict], # extra configuration for each node
max_folders_per_node: int = 2, # allows calculating shape
@@ -618,7 +617,7 @@ class ActionManager:
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
"""Init method for ActionManager.
@@ -649,7 +648,6 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.game: "PrimaiteGame" = game
self.node_names: List[str] = [n["node_name"] for n in nodes]
"""List of node names in this action space. The list order is the mapping between node index and node name."""
self.application_names: List[List[str]] = []
@@ -707,25 +705,7 @@ class ActionManager:
self.protocols: List[str] = protocols
self.ports: List[str] = ports
self.ip_address_list: List[str]
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
# the nodes in the simulation.
# TODO: refactor. Options:
# 1: This should be pulled out into its own function for clarity
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
# go through the nodes here.
if ip_address_list is not None:
self.ip_address_list = ip_address_list
else:
self.ip_address_list = []
for node_name in self.node_names:
node_obj = self.game.simulation.network.get_node_by_hostname(node_name)
if node_obj is None:
continue
network_interfaces = node_obj.network_interfaces
for nic_uuid, nic_obj in network_interfaces.items():
self.ip_address_list.append(nic_obj.ip_address)
self.ip_address_list: List[str] = ip_address_list
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
@@ -832,6 +812,13 @@ class ActionManager:
:return: The node hostname.
:rtype: str
"""
if not node_idx < len(self.node_names):
msg = (
f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
f" has {len(self.node_names)} nodes."
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.node_names[node_idx]
def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
@@ -845,6 +832,13 @@ class ActionManager:
:return: The name of the folder. Or None if the node has fewer folders than the given index.
:rtype: Optional[str]
"""
if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
f" is out of range for its action space. Folder on each node: {self.folder_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.folder_names[node_idx][folder_idx]
def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
@@ -860,6 +854,17 @@ class ActionManager:
fewer files than the given index.
:rtype: Optional[str]
"""
if (
node_idx >= len(self.file_names)
or folder_idx >= len(self.file_names[node_idx])
or file_idx >= len(self.file_names[node_idx][folder_idx])
):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
f" but this is out of range for its action space. Files on each node: {self.file_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.file_names[node_idx][folder_idx][file_idx]
def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
@@ -872,6 +877,13 @@ class ActionManager:
:return: The name of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
f" is out of range for its action space. Services on each node: {self.service_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.service_names[node_idx][service_idx]
def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
@@ -884,6 +896,13 @@ class ActionManager:
:return: The name of the application. Or None if the node has fewer applications than the given index.
:rtype: Optional[str]
"""
if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
msg = (
f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
f"this is out of range for its action space. Applications on each node: {self.application_names}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.application_names[node_idx][application_idx]
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
@@ -894,6 +913,13 @@ class ActionManager:
:return: The protocol.
:rtype: str
"""
if protocol_idx >= len(self.protocols):
msg = (
f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
f" is out of range for its action space. Protocols: {self.protocols}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.protocols[protocol_idx]
def get_ip_address_by_idx(self, ip_idx: int) -> str:
@@ -905,6 +931,13 @@ class ActionManager:
:return: The IP address.
:rtype: str
"""
if ip_idx >= len(self.ip_address_list):
msg = (
f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
f" is out of range for its action space. IP address list: {self.ip_address_list}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.ip_address_list[ip_idx]
def get_port_by_idx(self, port_idx: int) -> str:
@@ -916,6 +949,13 @@ class ActionManager:
:return: The port.
:rtype: str
"""
if port_idx >= len(self.ports):
msg = (
f"Error: agent attempted to perform an action on port {port_idx} but this"
f" is out of range for its action space. Port list: {self.ports}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.ports[port_idx]
def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
@@ -958,6 +998,12 @@ class ActionManager:
:return: The constructed ActionManager.
:rtype: ActionManager
"""
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
# the nodes in the simulation.
# TODO: refactor. Options:
# 1: This should be pulled out into its own function for clarity
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
# go through the nodes here.
ip_address_order = cfg["options"].pop("ip_address_order", {})
ip_address_list = []
for entry in ip_address_order:
@@ -967,13 +1013,22 @@ class ActionManager:
ip_address = node_obj.network_interface[nic_num].ip_address
ip_address_list.append(ip_address)
if not ip_address_list:
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
for node_name in node_names:
node_obj = game.simulation.network.get_node_by_hostname(node_name)
if node_obj is None:
continue
network_interfaces = node_obj.network_interfaces
for nic_uuid, nic_obj in network_interfaces.items():
ip_address_list.append(nic_obj.ip_address)
obj = cls(
game=game,
actions=cfg["action_list"],
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=ip_address_list or None,
ip_address_list=ip_address_list,
act_map=cfg.get("action_map"),
)

View File

@@ -26,22 +26,20 @@ class DataManipulationAgent(AbstractScriptedAgent):
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
:param obs: _description_
:param obs: Current observation for this agent, not used in DataManipulationAgent
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:param timestep: The current simulation timestep, used for scheduling actions
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.game.step_counter
if current_timestep < self.next_execution_timestep:
if timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}

View File

@@ -112,7 +112,7 @@ class AbstractAgent(ABC):
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -122,6 +122,8 @@ class AbstractAgent(ABC):
:type obs: ObsType
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
:type reward: float, optional
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
:type timestep: int
:return: Action to be taken in the environment.
:rtype: Tuple[str, Dict]
"""
@@ -144,20 +146,20 @@ class AbstractAgent(ABC):
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""
...
pass
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Sample the action space randomly.
:param obs: _description_
:param obs: Current observation for this agent, not used in RandomAgent
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:param timestep: The current simulation timestep, not used in RandomAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())
@@ -183,14 +185,14 @@ class ProxyAgent(AbstractAgent):
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Return the agent's most recent action, formatted in CAOS format.
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
:type obs: ObsType
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
:type reward: float, optional
:param timestep: Current simulation timestep. Not used by ProxyAgents, but required for the interface.
:type timestep: int
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""

View File

@@ -242,6 +242,48 @@ class WebpageUnavailablePenalty(AbstractReward):
return cls(node_hostname=node_hostname)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
def calculate(self, state: Dict) -> float:
"""
Calculate the reward based on current simulation state.
:param state: The current state of the simulation.
:type state: Dict
"""
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
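_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
# Without an early return, the lookup on the next line would fail when the state is missing.
return 0.0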
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
class RewardFunction:
"""Manages the reward function for the agent."""
@@ -250,6 +292,7 @@ class RewardFunction:
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
}
"""List of reward class identifiers."""

View File

@@ -1,14 +1,87 @@
"""Agents with predefined behaviours."""
from typing import Dict, Optional, Tuple
import numpy as np
import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
class GreenWebBrowsingAgent(AbstractScriptedAgent):
"""Scripted agent which attempts to send web requests to a target node."""
class ProbabilisticAgent(AbstractScriptedAgent):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
raise NotImplementedError
class Settings(pydantic.BaseModel):
"""Config schema for Probabilistic agent settings."""
model_config = pydantic.ConfigDict(extra="forbid")
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
class RedDatabaseCorruptingAgent(AbstractScriptedAgent):
"""Scripted agent which attempts to corrupt the database of the target node."""
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
"""Make sure the probabilities sum to 1."""
if not abs(sum(v.values()) - 1) < 1e-6:
raise ValueError("Green action probabilities must sum to 1")
return v
raise NotImplementedError
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]:
"""Ensure that the keys of the probability dictionary cover all integers from 0 to N."""
if not all((i in v) for i in range(len(v))):
raise ValueError(
"Green action probabilities must be defined as a mapping where the keys are consecutive integers "
"from 0 to N."
)
return v
def __init__(
self,
agent_name: str,
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
settings: Dict = {},
) -> None:
# If the action probabilities are not specified, create equal probabilities for all actions
if "action_probabilities" not in settings:
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
# convert probabilities from a dict to an array
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Sample the action space randomly.
The probability of each action is given by the corresponding index in ``self.probabilities``.
:param obs: Current observation for this agent, not used in ProbabilisticAgent
:type obs: ObsType
:param timestep: The current simulation timestep, not used in ProbabilisticAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
return self.action_manager.get_action(choice)

View File

@@ -7,9 +7,10 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -164,8 +165,7 @@ class PrimaiteGame:
agent_actions = {}
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
@@ -299,7 +299,7 @@ class PrimaiteGame:
if service_type == "DatabaseService":
if "options" in service_cfg:
opt = service_cfg["options"]
new_service.password = opt.get("backup_server_ip", None)
new_service.password = opt.get("db_password", None)
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
if service_type == "FTPServer":
if "options" in service_cfg:
@@ -412,20 +412,19 @@ class PrimaiteGame:
# CREATE REWARD FUNCTION
reward_function = RewardFunction.from_config(reward_function_cfg)
# OTHER AGENT SETTINGS
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing
new_agent = RandomAgent(
settings = agent_cfg.get("agent_settings")
new_agent = ProbabilisticAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
settings=settings,
)
elif agent_type == "ProxyAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
@@ -435,6 +434,8 @@ class PrimaiteGame:
)
game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,

View File

@@ -13,7 +13,7 @@
"source": [
"## Scenario\n",
"\n",
"The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n",
"The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database. Occasionally, admins need to access the database directly from the clients.\n",
"\n",
"[<img src=\"_package_data/uc2_network.png\" width=\"500\"/>](_package_data/uc2_network.png)\n",
"\n",
@@ -46,7 +46,9 @@
"source": [
"## Green agent\n",
"\n",
"There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available."
"There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available.\n",
"\n",
"Sometimes, the green agents send a request directly to the database to check that it is reachable."
]
},
{
@@ -68,7 +70,9 @@
"source": [
"## Blue agent\n",
"\n",
"The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router."
"The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router.\n",
"\n",
"However, these rules will also impact the green agents' ability to check the database connection. The blue agent should block only the infected client and let the other client connect freely. Once the attack has begun, automated traffic monitoring will detect it as suspicious network traffic. The blue agent's observation space will show this as an increase in the number of malicious network events (NMNE) on one of the network interfaces. To achieve optimal reward, the agent should only block the client which has the non-zero outbound NMNE."
]
},
{
@@ -101,9 +105,11 @@
"The red agent does not use information about the state of the network to decide its action.\n",
"\n",
"### Green\n",
"The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n",
"The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random: it will do nothing 30% of the time, send a web request 60% of the time, and send a db status check 10% of the time.\n",
"\n",
"When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender."
"When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender.\n",
"\n",
"Also, when the green agent is blocked from checking the database status, it causes a small negative reward."
]
},
{
@@ -322,9 +328,10 @@
"source": [
"## Reward Function\n",
"\n",
"The blue agent's reward is calculated using two measures:\n",
"The blue agent's reward is calculated using these measures:\n",
"1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n",
"2. Whether each green agent's most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n",
"3. Whether each green agent's most recent DB status check was successful (+1 for a successful connection, -1 for no connection).\n",
"\n",
"The file status reward and the two green-agent-related rewards are averaged to get a total step reward.\n"
]
@@ -346,7 +353,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
@@ -356,7 +365,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Imports\n",
@@ -403,7 +414,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -1.0 when green agents try to access the webpage."
"The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -0.8 when green agents try to access the webpage."
]
},
{
@@ -432,7 +443,7 @@
"source": [
"for step in range(35):\n",
" obs, reward, terminated, truncated, info = env.step(0)\n",
" print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )"
]
},
{
@@ -477,6 +488,13 @@
"File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent."
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -506,7 +524,7 @@
"\n",
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
"\n",
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
]
},
{
@@ -518,9 +536,9 @@
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
"print(f\"Blue reward:{reward}\" )"
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
"print(f\"Blue reward:{reward:.2f}\" )"
]
},
{
@@ -539,24 +557,24 @@
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(26) # Block client 1\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(27) # Block client 2\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"\n",
"for step in range(30):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, even though the red agent executes an attack, the reward stays at 1.0"
"Now, even though the red agent executes an attack, the reward stays at 0.8."
]
},
{
@@ -575,6 +593,46 @@
"obs['ACL']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can slightly increase the reward by unblocking the client which isn't being used by the attacker. If node 6 has outbound NMNEs, let's unblock client 2, and if node 7 has outbound NMNEs, let's unblock client 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's unblock client 2\n",
" env.step(34) # remove ACL rule 6\n",
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" env.step(33) # remove ACL rule 5\n",
"else:\n",
" print(\"something went wrong, neither client has NMNEs\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for step in range(30):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -590,13 +648,6 @@
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -1,3 +1,4 @@
import copy
import json
from typing import Any, Dict, Optional, SupportsFloat, Tuple
@@ -23,7 +24,7 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
@@ -78,7 +79,7 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
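
The two hunks above wrap `self.game_config` in `copy.deepcopy` before handing it to `PrimaiteGame.from_config`, presumably so that nothing the factory does to the dictionary can leak into later episodes. The sketch below illustrates that failure mode with a toy factory. `build_game` is a hypothetical stand-in, and the assumption that a factory might mutate its input is ours, not something this diff states.

```python
import copy
from typing import Dict


def build_game(cfg: Dict) -> Dict:
    """Hypothetical factory that consumes (and so mutates) its config while parsing."""
    agents = cfg["agents"]
    agents.pop()  # destructive read of one entry
    return {"n_agents_left": len(agents)}


game_config = {"agents": ["green", "red", "blue"]}

# Without a copy per episode, each build sees a config damaged by the previous one.
shared = copy.deepcopy(game_config)
build_game(shared)
build_game(shared)
leaked = len(shared["agents"])

# With a deep copy per episode, the stored config is left untouched.
for _ in range(2):
    build_game(copy.deepcopy(game_config))
preserved = len(game_config["agents"])

print(leaked, preserved)
```

A shallow copy would not be enough in this toy case, because the nested `agents` list would still be shared between the copy and the original.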

View File

@@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
@@ -40,7 +39,7 @@ class SessionMode(Enum):
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self, game: PrimaiteGame):
def __init__(self, game_cfg: Dict):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
@@ -57,7 +56,7 @@ class PrimaiteSession:
self.io_manager: Optional["SessionIO"] = None
"""IO manager for the session."""
self.game: PrimaiteGame = game
self.game_cfg: Dict = game_cfg
"""Game configuration dictionary, used to build the PrimaiteGame that runs the main simulation loop and agents."""
def start_session(self) -> None:
@@ -93,9 +92,7 @@ class PrimaiteSession:
io_settings = cfg.get("io_settings", {})
io_manager = SessionIO(SessionIOSettings(**io_settings))
game = PrimaiteGame.from_config(cfg)
sess = cls(game=game)
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])

View File

@@ -146,6 +146,9 @@ def arcd_uc2_network() -> Network:
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
db_client_1 = client_1.software_manager.software.get("DatabaseClient")
db_client_1.run()
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
db_manipulation_bot.configure(
@@ -165,6 +168,9 @@ def arcd_uc2_network() -> Network:
start_up_duration=0,
)
client_2.power_on()
client_2.software_manager.install(DatabaseClient)
db_client_2 = client_2.software_manager.software.get("DatabaseClient")
db_client_2.run()
web_browser = client_2.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
@@ -194,67 +200,10 @@ def arcd_uc2_network() -> Network:
database_server.power_on()
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3])
ddl = """
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(50) NOT NULL,
email VARCHAR(50) NOT NULL,
age INT,
city VARCHAR(50),
occupation VARCHAR(50)
);"""
user_insert_statements = [
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');",
# noqa
]
database_server.software_manager.install(DatabaseService)
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
database_service.start()
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
database_service._process_sql(ddl, None, None) # noqa
for insert_statement in user_insert_statements:
database_service._process_sql(insert_statement, None, None) # noqa
# Web Server
web_server = Server(

View File

@@ -15,6 +15,8 @@ class IPProtocol(Enum):
.. _List of IPProtocols:
"""
NONE = "none"
"""Placeholder for a non-protocol."""
TCP = "tcp"
"""Transmission Control Protocol."""
UDP = "udp"

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
from uuid import uuid4
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
@@ -25,6 +26,8 @@ class DatabaseClient(Application):
server_password: Optional[str] = None
connected: bool = False
_query_success_tracker: Dict[str, bool] = {}
_last_connection_successful: Optional[bool] = None
"""Whether the most recent connection check succeeded (None if no check has been made yet). Used for rewards."""
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseClient"
@@ -32,14 +35,30 @@ class DatabaseClient(Application):
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("execute", RequestType(func=lambda request, context: self.execute()))
return rm
def execute(self) -> bool:
"""Execution definition for db client: perform a select query."""
if self.connections:
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
else:
can_connect = self.check_connection(connection_id=str(uuid4()))
self._last_connection_successful = can_connect
return can_connect
def describe_state(self) -> Dict:
"""
Describes the current state of the DatabaseClient.
:return: A dictionary representing the current state.
"""
pass
return super().describe_state()
state = super().describe_state()
# whether the most recent connection check (via execute) was successful
state["last_connection_successful"] = self._last_connection_successful
return state
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
"""
@@ -65,6 +84,18 @@ class DatabaseClient(Application):
)
return self.connected
def check_connection(self, connection_id: str) -> bool:
"""Check whether the connection can be successfully re-established.
:param connection_id: connection ID to check
:type connection_id: str
:return: Whether the connection was successfully re-established.
:rtype: bool
"""
if not self._can_perform_action():
return False
return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
def _connect(
self,
server_ip_address: IPv4Address,

View File

@@ -1,10 +1,13 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from typing import Dict, Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
@@ -32,12 +35,10 @@ class DataManipulationAttackStage(IntEnum):
"Signifies that the attack has failed."
class DataManipulationBot(DatabaseClient):
class DataManipulationBot(Application):
"""A bot that simulates a script which performs a SQL injection attack."""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
@@ -46,8 +47,31 @@ class DataManipulationBot(DatabaseClient):
"Whether to repeat attacking once finished."
def __init__(self, **kwargs):
kwargs["name"] = "DataManipulationBot"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
return state
@property
def _host_db_client(self) -> DatabaseClient:
"""Return the database client that is installed on the same machine as the DataManipulationBot."""
db_client = self.software_manager.software.get("DatabaseClient")
if db_client is None:
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
return db_client
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -76,8 +100,8 @@ class DataManipulationBot(DatabaseClient):
:param repeat: Whether to repeat attacking once finished.
"""
self.server_ip_address = server_ip_address
self.payload = payload
self.server_password = server_password
self.payload = payload
self.port_scan_p_of_success = port_scan_p_of_success
self.data_manipulation_p_of_success = data_manipulation_p_of_success
self.repeat = repeat
@@ -123,15 +147,21 @@ class DataManipulationBot(DatabaseClient):
:param p_of_success: Probability of successfully performing data manipulation, by default 0.1.
"""
if self._host_db_client is None:
self.attack_stage = DataManipulationAttackStage.FAILED
return
self._host_db_client.server_ip_address = self.server_ip_address
self._host_db_client.server_password = self.server_password
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
# perform the actual data manipulation attack
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing data manipulation")
# perform the attack
if not len(self.connections):
self.connect()
if len(self.connections):
self.query(self.payload)
if not len(self._host_db_client.connections):
self._host_db_client.connect()
if len(self._host_db_client.connections):
self._host_db_client.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
attack_successful = True
if attack_successful:

View File

@@ -6,7 +6,6 @@ from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
@@ -28,7 +27,7 @@ class DoSAttackStage(IntEnum):
"Attack is completed."
class DoSBot(DatabaseClient, Application):
class DoSBot(DatabaseClient):
"""A bot that simulates a Denial of Service attack."""
target_ip_address: Optional[IPv4Address] = None

View File

@@ -199,7 +199,7 @@ class WebBrowser(Application):
def state(self) -> Dict:
"""Return the contents of this dataclass as a dict for use with describe_state method."""
if self.status == self._HistoryItemStatus.LOADED:
outcome = self.response_code
outcome = self.response_code.value
else:
outcome = self.status.value
return {"url": self.url, "outcome": outcome}
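
The change from `self.response_code` to `self.response_code.value` suggests the response code is an enum member and that the state dictionary should carry its plain value. The sketch below shows why that can matter, assuming a plain `Enum`; `HttpStatusCode` here is a hypothetical stand-in and may differ from the type used in the codebase.

```python
import json
from enum import Enum


class HttpStatusCode(Enum):
    """Hypothetical plain Enum of response codes."""

    OK = 200
    NOT_FOUND = 404


raw = {"outcome": HttpStatusCode.OK}          # enum member stored directly
plain = {"outcome": HttpStatusCode.OK.value}  # plain integer stored

# A plain Enum member does not compare equal to its integer value.
matches_raw = raw["outcome"] == 200
matches_plain = plain["outcome"] == 200

# The plain value serialises; the enum member does not.
serialised = json.dumps(plain)
try:
    json.dumps(raw)
    raw_error = False
except TypeError:
    raw_error = True

print(matches_raw, matches_plain, serialised, raw_error)
```

A reward component that compares the outcome against `200` or `404` would therefore only behave as intended with the `.value` form, under the assumption stated above.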

View File

@@ -221,6 +221,18 @@ class DatabaseService(Service):
}
else:
return {"status_code": 404, "data": False}
elif query == "SELECT * FROM pg_stat_activity":
# Check if the connection is active.
if self.health_state_actual == SoftwareHealthState.GOOD:
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
else:
return {"status_code": 401, "data": False}
else:
# Invalid query
self.sys_log.info(f"{self.name}: Invalid {query}")

View File

@@ -21,7 +21,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -40,7 +40,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -40,7 +40,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -65,7 +65,7 @@ game:
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -25,7 +25,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -31,7 +31,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -29,7 +29,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -25,7 +25,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import pytest
import yaml
@@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent):
)
self.most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]:
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
return self.most_recent_action
@@ -497,7 +497,6 @@ def game_and_agent():
]
action_space = ActionManager(
game=game,
actions=actions, # ALL POSSIBLE ACTIONS
nodes=[
{

View File

@@ -21,15 +21,15 @@ class TestPrimaiteSession:
raise AssertionError
assert session is not None
assert session.game.simulation
assert len(session.game.agents) == 3
assert len(session.game.rl_agents) == 1
assert session.env.game.simulation
assert len(session.env.game.agents) == 3
assert len(session.env.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.game.simulation.network
assert len(session.game.simulation.network.nodes) == 10
assert session.env.game.simulation.network
assert len(session.env.game.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):

View File

@@ -1,7 +1,10 @@
from primaite.game.agent.rewards import WebpageUnavailablePenalty
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
from tests.conftest import ControlledAgent
@@ -35,3 +38,45 @@ def test_WebpageUnavailablePenalty(game_and_agent):
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == -0.7
def test_uc2_rewards(game_and_agent):
"""Test that the reward component correctly applies a penalty when the selected client cannot reach the database."""
game, agent = game_and_agent
agent: ControlledAgent
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
server_1.software_manager.install(DatabaseService)
db_service = server_1.software_manager.software.get("DatabaseService")
db_service.start()
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
db_client.run()
router: Router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2)
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
assert reward_value == 1.0
router.acl.remove_rule(position=2)
db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
assert reward_value == -1.0
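
The test above expects `+1.0` after a successful status check and `-1.0` after a blocked one. A minimal sketch of a reward rule with that behaviour is shown below. It is not the `GreenAdminDatabaseUnreachablePenalty` implementation: the state path, the helper names and the choice of `0.0` when no check has been made yet are all assumptions.

```python
from typing import Any, Dict, List, Optional


def lookup(state: Dict[str, Any], path: List[str]) -> Optional[Any]:
    """Walk a nested state dictionary, returning None if any key is missing."""
    current: Any = state
    for key in path:
        if not isinstance(current, dict) or key not in current:
            return None
        current = current[key]
    return current


def db_reachability_reward(state: Dict[str, Any], path: List[str]) -> float:
    """+1.0 if the last check succeeded, -1.0 if it failed, 0.0 if unknown (assumed)."""
    last_ok = lookup(state, path)
    if last_ok is True:
        return 1.0
    if last_ok is False:
        return -1.0
    return 0.0


# Hypothetical state layout, for illustration only.
PATH = ["nodes", "client_1", "applications", "DatabaseClient", "last_connection_successful"]


def make_state(value: Optional[bool]) -> Dict[str, Any]:
    return {"nodes": {"client_1": {"applications": {"DatabaseClient": {"last_connection_successful": value}}}}}


print(db_reachability_reward(make_state(True), PATH))
print(db_reachability_reward(make_state(False), PATH))
print(db_reachability_reward({}, PATH))
```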

View File

@@ -0,0 +1,84 @@
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
def test_probabilistic_agent():
"""
Check that the probabilistic agent selects actions with approximately the right probabilities.
Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator),
we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values
were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number
generation.
"""
N_TRIALS = 10_000
P_DO_NOTHING = 0.1
P_NODE_APPLICATION_EXECUTE = 0.3
P_NODE_FILE_DELETE = 0.6
MIN_DO_NOTHING = 850
MAX_DO_NOTHING = 1150
MIN_NODE_APPLICATION_EXECUTE = 2800
MAX_NODE_APPLICATION_EXECUTE = 3200
MIN_NODE_FILE_DELETE = 5750
MAX_NODE_FILE_DELETE = 6250
action_space = ActionManager(
actions=[
{"type": "DONOTHING"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_FILE_DELETE"},
],
nodes=[
{
"node_name": "client_1",
"applications": [{"application_name": "WebBrowser"}],
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
},
],
max_folders_per_node=2,
max_files_per_folder=2,
max_services_per_node=2,
max_applications_per_node=2,
max_nics_per_node=2,
max_acl_rules=10,
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
act_map={
0: {"action": "DONOTHING", "options": {}},
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
},
)
observation_space = ObservationManager(ICSObservation())
reward_function = RewardFunction()
pa = ProbabilisticAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
settings={
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
"random_seed": 120,
},
)
do_nothing_count = 0
node_application_execute_count = 0
node_file_delete_count = 0
for _ in range(N_TRIALS):
a = pa.get_action(0)
if a == ("DONOTHING", {}):
do_nothing_count += 1
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
node_application_execute_count += 1
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
node_file_delete_count += 1
else:
raise AssertionError("Probabilistic agent produced an unexpected action.")
assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING
assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE
assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE
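
The docstring of the test above refers to bounds taken from a binomial calculator. The sketch below shows one way such bounds could be checked directly, by summing the exact binomial tails in log space. The function names are hypothetical, and this is offered as a way to inspect the bounds rather than as the method the author used. In units of standard deviation the three bands sit at roughly 5.0, 4.4 and 5.1 either side of the mean, so the middle band is the tightest.

```python
import math


def log_pmf(k: int, n: int, p: float) -> float:
    """Natural log of the binomial probability of exactly k successes in n trials."""
    return (
        math.lgamma(n + 1)
        - math.lgamma(k + 1)
        - math.lgamma(n - k + 1)
        + k * math.log(p)
        + (n - k) * math.log1p(-p)
    )


def failure_probability(n: int, p: float, lower: int, upper: int) -> float:
    """Probability that the strict check `lower < count < upper` fails."""
    low_tail = sum(math.exp(log_pmf(k, n, p)) for k in range(0, lower + 1))
    high_tail = sum(math.exp(log_pmf(k, n, p)) for k in range(upper, n + 1))
    return low_tail + high_tail


N_TRIALS = 10_000
# (probability, lower bound, upper bound) as used in the test above.
BANDS = {
    "DONOTHING": (0.1, 850, 1150),
    "NODE_APPLICATION_EXECUTE": (0.3, 2800, 3200),
    "NODE_FILE_DELETE": (0.6, 5750, 6250),
}

results = {name: failure_probability(N_TRIALS, p, lo, hi) for name, (p, lo, hi) in BANDS.items()}
for name, prob in results.items():
    print(f"{name}: {prob:.3e}")
```

Because the agent in the test is constructed with a fixed `random_seed`, these probabilities would matter mainly if the seed or the sampling code changes.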

View File

@@ -26,8 +26,8 @@ def test_create_dm_bot(dm_client):
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot")
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.port == Port.NONE
assert data_manipulation_bot.protocol == IPProtocol.NONE
assert data_manipulation_bot.payload == "DELETE"
@@ -70,4 +70,13 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
assert len(dm_bot.connections)
assert len(dm_bot._host_db_client.connections)
def test_dm_bot_fails_without_db_client(dm_client):
dm_client.software_manager.uninstall("DatabaseClient")
dm_bot = dm_client.software_manager.software.get("DataManipulationBot")
assert dm_bot._host_db_client is None
dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage is DataManipulationAttackStage.FAILED