Merged PR 285: Green agent that sometimes performs database connections
## Summary * removed `GreenWebBrowsingAgent` because it was replaced by probabilistic agent * created new 'probabilistic agent' which selects actions randomly from its action map, with configurable probabilities * slightly refactored action manager to decouple it from `PrimaiteGame` (as a consequence, agents should be given the current timestep if their `get_action()` method is time-dependent) * refactored `data_manipulation_bot` to use an existing db client on the host rather than inheriting from it * added new type of SQL query to databases: `"SELECT * FROM pg_stat_activity"` to model checking connection status * added new execution definition on the `DatabaseClient` app which just performs that new SQL query * added reward for the green admin being able to connect to the db * updated uc2 notebook to reflect new changes. * updated documentation for data manipulation bot * added new test for probabilistic agent * added test for new reward ## Checklist - [x] PR is linked to a **work item** - [x] **acceptance criteria** of linked ticket are met - [x] performed **self-review** of the code - [x] written **tests** for any new functionality added with this PR - [x] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [x] updated the **change log** - [x] ran **pre-commit** checks for code style - [x] attended to any **TO-DOs** left in the code Related work items: #2319
This commit is contained in:
@@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
|
||||
...
|
||||
- ref: green_agent_example
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
@@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive
|
||||
|
||||
``type``
|
||||
--------
|
||||
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour.
|
||||
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour.
|
||||
|
||||
Available agent types:
|
||||
|
||||
- ``GreenWebBrowsingAgent``
|
||||
- ``ProbabilisticAgent``
|
||||
- ``ProxyAgent``
|
||||
- ``RedDatabaseCorruptingAgent``
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ In a simulation, the bot can be controlled by using ``DataManipulationAgent`` wh
|
||||
Implementation
|
||||
==============
|
||||
|
||||
The bot extends :ref:`DatabaseClient` and leverages its connectivity.
|
||||
The bot connects to a :ref:`DatabaseClient` and leverages its connectivity. The host running ``DataManipulationBot`` must also have a :ref:`DatabaseClient` installed on it.
|
||||
|
||||
- Uses the Application base class for lifecycle management.
|
||||
- Credentials, target IP and other options set via ``configure``.
|
||||
@@ -65,6 +65,7 @@ Python
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
client_1 = Computer(
|
||||
hostname="client_1",
|
||||
@@ -74,6 +75,7 @@ Python
|
||||
operating_state=NodeOperatingState.ON # initialise the computer in an ON state
|
||||
)
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
@@ -148,6 +150,10 @@ If not using the data manipulation bot manually, it needs to be used with a data
|
||||
data_manipulation_p_of_success: 0.1
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
Configuration
|
||||
=============
|
||||
|
||||
@@ -25,6 +25,13 @@ Usage
|
||||
- Clients connect, execute queries, and disconnect.
|
||||
- Service runs on TCP port 5432 by default.
|
||||
|
||||
**Supported queries:**
|
||||
|
||||
* ``SELECT``: As long as the database file is in a ``GOOD`` health state, the db service will respond with a 200 status code.
|
||||
* ``DELETE``: This query represents an attack, it will cause the database file to enter a ``COMPROMISED`` state, and return a status code 200.
|
||||
* ``INSERT``: If the database service is in a healthy state, this will return a 200 status, if it's not in a healthy state it will return 404.
|
||||
* ``SELECT * FROM pg_stat_activity``: This query represents something an admin would send to check the status of the database. If the database service is in a healthy state, it returns a 200 status code, otherwise a 401 status code.
|
||||
|
||||
Implementation
|
||||
==============
|
||||
|
||||
|
||||
@@ -33,7 +33,12 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
@@ -45,24 +50,38 @@ agents:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
@@ -74,10 +93,26 @@ agents:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
@@ -85,6 +120,7 @@ agents:
|
||||
|
||||
|
||||
|
||||
|
||||
- ref: data_manipulation_attacker
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
@@ -572,6 +608,14 @@ agents:
|
||||
weight: 0.33
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.1
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.1
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
|
||||
agent_settings:
|
||||
@@ -722,6 +766,10 @@ simulation:
|
||||
type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- ref: client_1_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
@@ -745,6 +793,10 @@ simulation:
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: client_2_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
@@ -35,7 +35,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
@@ -64,7 +64,7 @@ agents:
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -607,7 +607,6 @@ class ActionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
game: "PrimaiteGame", # reference to game for information lookup
|
||||
actions: List[Dict], # stores list of actions available to agent
|
||||
nodes: List[Dict], # extra configuration for each node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
@@ -618,7 +617,7 @@ class ActionManager:
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
|
||||
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
|
||||
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
@@ -649,7 +648,6 @@ class ActionManager:
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.node_names: List[str] = [n["node_name"] for n in nodes]
|
||||
"""List of node names in this action space. The list order is the mapping between node index and node name."""
|
||||
self.application_names: List[List[str]] = []
|
||||
@@ -707,25 +705,7 @@ class ActionManager:
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
self.ip_address_list: List[str]
|
||||
|
||||
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
|
||||
# the nodes in the simulation.
|
||||
# TODO: refactor. Options:
|
||||
# 1: This should be pulled out into it's own function for clarity
|
||||
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
|
||||
# go through the nodes here.
|
||||
if ip_address_list is not None:
|
||||
self.ip_address_list = ip_address_list
|
||||
else:
|
||||
self.ip_address_list = []
|
||||
for node_name in self.node_names:
|
||||
node_obj = self.game.simulation.network.get_node_by_hostname(node_name)
|
||||
if node_obj is None:
|
||||
continue
|
||||
network_interfaces = node_obj.network_interfaces
|
||||
for nic_uuid, nic_obj in network_interfaces.items():
|
||||
self.ip_address_list.append(nic_obj.ip_address)
|
||||
self.ip_address_list: List[str] = ip_address_list
|
||||
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
@@ -832,6 +812,13 @@ class ActionManager:
|
||||
:return: The node hostname.
|
||||
:rtype: str
|
||||
"""
|
||||
if not node_idx < len(self.node_names):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
|
||||
f"has {len(self.node_names)} nodes."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.node_names[node_idx]
|
||||
|
||||
def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
|
||||
@@ -845,6 +832,13 @@ class ActionManager:
|
||||
:return: The name of the folder. Or None if the node has fewer folders than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
|
||||
f" is out of range for its action space. Folder on each node: {self.folder_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.folder_names[node_idx][folder_idx]
|
||||
|
||||
def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
|
||||
@@ -860,6 +854,17 @@ class ActionManager:
|
||||
fewer files than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if (
|
||||
node_idx >= len(self.file_names)
|
||||
or folder_idx >= len(self.file_names[node_idx])
|
||||
or file_idx >= len(self.file_names[node_idx][folder_idx])
|
||||
):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
|
||||
f" but this is out of range for its action space. Files on each node: {self.file_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.file_names[node_idx][folder_idx][file_idx]
|
||||
|
||||
def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
|
||||
@@ -872,6 +877,13 @@ class ActionManager:
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
|
||||
f" is out of range for its action space. Services on each node: {self.service_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.service_names[node_idx][service_idx]
|
||||
|
||||
def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
|
||||
@@ -884,6 +896,13 @@ class ActionManager:
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
|
||||
f"this is out of range for its action space. Applications on each node: {self.application_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.application_names[node_idx][application_idx]
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
@@ -894,6 +913,13 @@ class ActionManager:
|
||||
:return: The protocol.
|
||||
:rtype: str
|
||||
"""
|
||||
if protocol_idx >= len(self.protocols):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
|
||||
f" is out of range for its action space. Protocols: {self.protocols}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
@@ -905,6 +931,13 @@ class ActionManager:
|
||||
:return: The IP address.
|
||||
:rtype: str
|
||||
"""
|
||||
if ip_idx >= len(self.ip_address_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
|
||||
f" is out of range for its action space. IP address list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
@@ -916,6 +949,13 @@ class ActionManager:
|
||||
:return: The port.
|
||||
:rtype: str
|
||||
"""
|
||||
if port_idx >= len(self.ports):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on port {port_idx} but this"
|
||||
f" is out of range for its action space. Port list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
|
||||
@@ -958,6 +998,12 @@ class ActionManager:
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
|
||||
# the nodes in the simulation.
|
||||
# TODO: refactor. Options:
|
||||
# 1: This should be pulled out into it's own function for clarity
|
||||
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
|
||||
# go through the nodes here.
|
||||
ip_address_order = cfg["options"].pop("ip_address_order", {})
|
||||
ip_address_list = []
|
||||
for entry in ip_address_order:
|
||||
@@ -967,13 +1013,22 @@ class ActionManager:
|
||||
ip_address = node_obj.network_interface[nic_num].ip_address
|
||||
ip_address_list.append(ip_address)
|
||||
|
||||
if not ip_address_list:
|
||||
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
|
||||
for node_name in node_names:
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_name)
|
||||
if node_obj is None:
|
||||
continue
|
||||
network_interfaces = node_obj.network_interfaces
|
||||
for nic_uuid, nic_obj in network_interfaces.items():
|
||||
ip_address_list.append(nic_obj.ip_address)
|
||||
|
||||
obj = cls(
|
||||
game=game,
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=ip_address_list or None,
|
||||
ip_address_list=ip_address_list,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
|
||||
@@ -26,22 +26,20 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
|
||||
|
||||
:param obs: _description_
|
||||
:param obs: Current observation for this agent, not used in DataManipulationAgent
|
||||
:type obs: ObsType
|
||||
:param reward: _description_, defaults to None
|
||||
:type reward: float, optional
|
||||
:return: _description_
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
current_timestep = self.action_manager.game.step_counter
|
||||
|
||||
if current_timestep < self.next_execution_timestep:
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {"dummy": 0}
|
||||
|
||||
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
return self.reward_function.update(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
@@ -122,6 +122,8 @@ class AbstractAgent(ABC):
|
||||
:type obs: ObsType
|
||||
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
|
||||
:type reward: float, optional
|
||||
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
|
||||
:type timestep: int
|
||||
:return: Action to be taken in the environment.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
@@ -144,20 +146,20 @@ class AbstractAgent(ABC):
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
...
|
||||
pass
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: _description_
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
:type obs: ObsType
|
||||
:param reward: _description_, defaults to None
|
||||
:type reward: float, optional
|
||||
:return: _description_
|
||||
:param timestep: The current simulation timestep, not used in RandomAgent
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
@@ -183,14 +185,14 @@ class ProxyAgent(AbstractAgent):
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
|
||||
:type reward: float, optional
|
||||
:param timestep: Current simulation timestep. Not used by ProxyAgents, bur required for the interface.
|
||||
:type timestep: int
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
|
||||
@@ -242,6 +242,48 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname)
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
def __init__(self, node_hostname: str) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
"""
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
last_connection_successful = db_state["last_connection_successful"]
|
||||
if last_connection_successful is False:
|
||||
return -1.0
|
||||
elif last_connection_successful is True:
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
"""
|
||||
Build the reward component object from config.
|
||||
|
||||
:param config: Configuration dictionary.
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
return cls(node_hostname=node_hostname)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
@@ -250,6 +292,7 @@ class RewardFunction:
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
|
||||
@@ -1,14 +1,87 @@
|
||||
"""Agents with predefined behaviours."""
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GreenWebBrowsingAgent(AbstractScriptedAgent):
|
||||
"""Scripted agent which attempts to send web requests to a target node."""
|
||||
class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
raise NotImplementedError
|
||||
class Settings(pydantic.BaseModel):
|
||||
"""Config schema for Probabilistic agent settings."""
|
||||
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
random_seed: Optional[int] = None
|
||||
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
class RedDatabaseCorruptingAgent(AbstractScriptedAgent):
|
||||
"""Scripted agent which attempts to corrupt the database of the target node."""
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
|
||||
"""Make sure the probabilities sum to 1."""
|
||||
if not abs(sum(v.values()) - 1) < 1e-6:
|
||||
raise ValueError("Green action probabilities must sum to 1")
|
||||
return v
|
||||
|
||||
raise NotImplementedError
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]:
|
||||
"""Ensure that the keys of the probability dictionary cover all integers from 0 to N."""
|
||||
if not all((i in v) for i in range(len(v))):
|
||||
raise ValueError(
|
||||
"Green action probabilities must be defined as a mapping where the keys are consecutive integers "
|
||||
"from 0 to N."
|
||||
)
|
||||
return v
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
settings: Dict = {},
|
||||
) -> None:
|
||||
# If the action probabilities are not specified, create equal probabilities for all actions
|
||||
if "action_probabilities" not in settings:
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
|
||||
# If seed not specified, set it to None so that numpy chooses a random one.
|
||||
settings.setdefault("random_seed")
|
||||
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
|
||||
self.rng = np.random.default_rng(self.settings.random_seed)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Sample the action space randomly.
|
||||
|
||||
The probability of each action is given by the corresponding index in ``self.probabilities``.
|
||||
|
||||
:param obs: Current observation for this agent, not used in ProbabilisticAgent
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, not used in ProbabilisticAgent
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
|
||||
return self.action_manager.get_action(choice)
|
||||
|
||||
@@ -7,9 +7,10 @@ from pydantic import BaseModel, ConfigDict
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents import ProbabilisticAgent
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -164,8 +165,7 @@ class PrimaiteGame:
|
||||
agent_actions = {}
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
@@ -299,7 +299,7 @@ class PrimaiteGame:
|
||||
if service_type == "DatabaseService":
|
||||
if "options" in service_cfg:
|
||||
opt = service_cfg["options"]
|
||||
new_service.password = opt.get("backup_server_ip", None)
|
||||
new_service.password = opt.get("db_password", None)
|
||||
new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip")))
|
||||
if service_type == "FTPServer":
|
||||
if "options" in service_cfg:
|
||||
@@ -412,20 +412,19 @@ class PrimaiteGame:
|
||||
# CREATE REWARD FUNCTION
|
||||
reward_function = RewardFunction.from_config(reward_function_cfg)
|
||||
|
||||
# OTHER AGENT SETTINGS
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
new_agent = RandomAgent(
|
||||
settings = agent_cfg.get("agent_settings")
|
||||
new_agent = ProbabilisticAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
settings=settings,
|
||||
)
|
||||
elif agent_type == "ProxyAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
@@ -435,6 +434,8 @@ class PrimaiteGame:
|
||||
)
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
"source": [
|
||||
"## Scenario\n",
|
||||
"\n",
|
||||
"The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n",
|
||||
"The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database. Occasionally, admins need to access the database directly from the clients.\n",
|
||||
"\n",
|
||||
"[<img src=\"_package_data/uc2_network.png\" width=\"500\"/>](_package_data/uc2_network.png)\n",
|
||||
"\n",
|
||||
@@ -46,7 +46,9 @@
|
||||
"source": [
|
||||
"## Green agent\n",
|
||||
"\n",
|
||||
"There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available."
|
||||
"There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available.\n",
|
||||
"\n",
|
||||
"Sometimes, the green agents send a request directly to the database to check that it is reachable."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -68,7 +70,9 @@
|
||||
"source": [
|
||||
"## Blue agent\n",
|
||||
"\n",
|
||||
"The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router."
|
||||
"The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router.\n",
|
||||
"\n",
|
||||
"However, these rules will also impact greens' ability to check the database connection. The blue agent should only block the infected client, it should let the other client connect freely. Once the attack has begun, automated traffic monitoring will detect it as suspicious network traffic. The blue agent's observation space will show this as an increase in the number of malicious network events (NMNE) on one of the network interfaces. To achieve optimal reward, the agent should only block the client which has the non-zero outbound NMNE."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -101,9 +105,11 @@
|
||||
"The red agent does not use information about the state of the network to decide its action.\n",
|
||||
"\n",
|
||||
"### Green\n",
|
||||
"The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n",
|
||||
"The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, it will do nothing 30% of the time, send a web request 60% of the time, and send a db status check 10% of the time.\n",
|
||||
"\n",
|
||||
"When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender."
|
||||
"When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender.\n",
|
||||
"\n",
|
||||
"Also, when the green agent is blocked from checking the database status, it causes a small negative reward."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -322,9 +328,10 @@
|
||||
"source": [
|
||||
"## Reward Function\n",
|
||||
"\n",
|
||||
"The blue agent's reward is calculated using two measures:\n",
|
||||
"The blue agent's reward is calculated using these measures:\n",
|
||||
"1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n",
|
||||
"2. Whether each green agents' most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n",
|
||||
"3. Whether each green agents' most recent DB status check was successful (+1 for a successful connection, -1 for no connection).\n",
|
||||
"\n",
|
||||
"The file status reward and the two green-agent-related rewards are averaged to get a total step reward.\n"
|
||||
]
|
||||
@@ -346,7 +353,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
@@ -356,7 +365,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
@@ -403,7 +414,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -1.0 when green agents try to access the webpage."
|
||||
"The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -0.8 when green agents try to access the webpage."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -432,7 +443,7 @@
|
||||
"source": [
|
||||
"for step in range(35):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -477,6 +488,13 @@
|
||||
"File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -506,7 +524,7 @@
|
||||
"\n",
|
||||
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
|
||||
"\n",
|
||||
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -518,9 +536,9 @@
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
|
||||
"print(f\"Blue reward:{reward}\" )"
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
|
||||
"print(f\"Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -539,24 +557,24 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(26) # Block client 1\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(27) # Block client 2\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"for step in range(30):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, even though the red agent executes an attack, the reward stays at 1.0"
|
||||
"Now, even though the red agent executes an attack, the reward stays at 0.8."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -575,6 +593,46 @@
|
||||
"obs['ACL']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can slightly increase the reward by unblocking the client which isn't being used by the attacker. If node 6 has outbound NMNEs, let's unblock client 2, and if node 7 has outbound NMNEs, let's unblock client 1."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's unblock client 2\n",
|
||||
" env.step(34) # remove ACL rule 6\n",
|
||||
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" env.step(33) # remove ACL rule 5\n",
|
||||
"else:\n",
|
||||
" print(\"something went wrong, neither client has NMNEs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for step in range(30):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -590,13 +648,6 @@
|
||||
"source": [
|
||||
"env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple
|
||||
|
||||
@@ -23,7 +24,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
super().__init__()
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
@@ -78,7 +79,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
|
||||
@@ -40,7 +39,7 @@ class SessionMode(Enum):
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
def __init__(self, game_cfg: Dict):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
@@ -57,7 +56,7 @@ class PrimaiteSession:
|
||||
self.io_manager: Optional["SessionIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game: PrimaiteGame = game
|
||||
self.game_cfg: Dict = game_cfg
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
@@ -93,9 +92,7 @@ class PrimaiteSession:
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
io_manager = SessionIO(SessionIOSettings(**io_settings))
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
sess = cls(game=game)
|
||||
sess = cls(game_cfg=cfg)
|
||||
sess.io_manager = io_manager
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
|
||||
@@ -146,6 +146,9 @@ def arcd_uc2_network() -> Network:
|
||||
)
|
||||
client_1.power_on()
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
db_client_1 = client_1.software_manager.install(DatabaseClient)
|
||||
db_client_1 = client_1.software_manager.software.get("DatabaseClient")
|
||||
db_client_1.run()
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
db_manipulation_bot.configure(
|
||||
@@ -165,6 +168,9 @@ def arcd_uc2_network() -> Network:
|
||||
start_up_duration=0,
|
||||
)
|
||||
client_2.power_on()
|
||||
client_2.software_manager.install(DatabaseClient)
|
||||
db_client_2 = client_2.software_manager.software.get("DatabaseClient")
|
||||
db_client_2.run()
|
||||
web_browser = client_2.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
|
||||
@@ -194,67 +200,10 @@ def arcd_uc2_network() -> Network:
|
||||
database_server.power_on()
|
||||
network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3])
|
||||
|
||||
ddl = """
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name VARCHAR(50) NOT NULL,
|
||||
email VARCHAR(50) NOT NULL,
|
||||
age INT,
|
||||
city VARCHAR(50),
|
||||
occupation VARCHAR(50)
|
||||
);"""
|
||||
|
||||
user_insert_statements = [
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');",
|
||||
# noqa
|
||||
"INSERT INTO user (name, email, age, city, occupation) "
|
||||
"VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');",
|
||||
# noqa
|
||||
]
|
||||
database_server.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
|
||||
database_service.start()
|
||||
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
|
||||
database_service._process_sql(ddl, None, None) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database_service._process_sql(insert_statement, None, None) # noqa
|
||||
|
||||
# Web Server
|
||||
web_server = Server(
|
||||
|
||||
@@ -15,6 +15,8 @@ class IPProtocol(Enum):
|
||||
.. _List of IPProtocols:
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
"""Placeholder for a non-protocol."""
|
||||
TCP = "tcp"
|
||||
"""Transmission Control Protocol."""
|
||||
UDP = "udp"
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
@@ -25,6 +26,8 @@ class DatabaseClient(Application):
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
_last_connection_successful: Optional[bool] = None
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
@@ -32,14 +35,30 @@ class DatabaseClient(Application):
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: self.execute()))
|
||||
return rm
|
||||
|
||||
def execute(self) -> bool:
|
||||
"""Execution definition for db client: perform a select query."""
|
||||
if self.connections:
|
||||
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
|
||||
else:
|
||||
can_connect = self.check_connection(connection_id=str(uuid4()))
|
||||
self._last_connection_successful = can_connect
|
||||
return can_connect
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Describes the current state of the ACLRule.
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
return super().describe_state()
|
||||
state = super().describe_state()
|
||||
# list of connections that were established or verified during this step.
|
||||
state["last_connection_successful"] = self._last_connection_successful
|
||||
return state
|
||||
|
||||
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
|
||||
"""
|
||||
@@ -65,6 +84,18 @@ class DatabaseClient(Application):
|
||||
)
|
||||
return self.connected
|
||||
|
||||
def check_connection(self, connection_id: str) -> bool:
|
||||
"""Check whether the connection can be successfully re-established.
|
||||
|
||||
:param connection_id: connection ID to check
|
||||
:type connection_id: str
|
||||
:return: Whether the connection was successfully re-established.
|
||||
:rtype: bool
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
|
||||
def _connect(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -32,12 +35,10 @@ class DataManipulationAttackStage(IntEnum):
|
||||
"Signifies that the attack has failed."
|
||||
|
||||
|
||||
class DataManipulationBot(DatabaseClient):
|
||||
class DataManipulationBot(Application):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
payload: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
|
||||
@@ -46,8 +47,31 @@ class DataManipulationBot(DatabaseClient):
|
||||
"Whether to repeat attacking once finished."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DataManipulationBot"
|
||||
kwargs["port"] = Port.NONE
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DataManipulationBot"
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
@property
|
||||
def _host_db_client(self) -> DatabaseClient:
|
||||
"""Return the database client that is installed on the same machine as the DataManipulationBot."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
@@ -76,8 +100,8 @@ class DataManipulationBot(DatabaseClient):
|
||||
:param repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
self.server_ip_address = server_ip_address
|
||||
self.payload = payload
|
||||
self.server_password = server_password
|
||||
self.payload = payload
|
||||
self.port_scan_p_of_success = port_scan_p_of_success
|
||||
self.data_manipulation_p_of_success = data_manipulation_p_of_success
|
||||
self.repeat = repeat
|
||||
@@ -123,15 +147,21 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
:param p_of_success: Probability of successfully performing data manipulation, by default 0.1.
|
||||
"""
|
||||
if self._host_db_client is None:
|
||||
self.attack_stage = DataManipulationAttackStage.FAILED
|
||||
return
|
||||
|
||||
self._host_db_client.server_ip_address = self.server_ip_address
|
||||
self._host_db_client.server_password = self.server_password
|
||||
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
|
||||
# perform the actual data manipulation attack
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing data manipulation")
|
||||
# perform the attack
|
||||
if not len(self.connections):
|
||||
self.connect()
|
||||
if len(self.connections):
|
||||
self.query(self.payload)
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
|
||||
@@ -6,7 +6,6 @@ from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -28,7 +27,7 @@ class DoSAttackStage(IntEnum):
|
||||
"Attack is completed."
|
||||
|
||||
|
||||
class DoSBot(DatabaseClient, Application):
|
||||
class DoSBot(DatabaseClient):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
|
||||
@@ -199,7 +199,7 @@ class WebBrowser(Application):
|
||||
def state(self) -> Dict:
|
||||
"""Return the contents of this dataclass as a dict for use with describe_state method."""
|
||||
if self.status == self._HistoryItemStatus.LOADED:
|
||||
outcome = self.response_code
|
||||
outcome = self.response_code.value
|
||||
else:
|
||||
outcome = self.status.value
|
||||
return {"url": self.url, "outcome": outcome}
|
||||
|
||||
@@ -221,6 +221,18 @@ class DatabaseService(Service):
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "SELECT * FROM pg_stat_activity":
|
||||
# Check if the connection is active.
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 401, "data": False}
|
||||
else:
|
||||
# Invalid query
|
||||
self.sys_log.info(f"{self.name}: Invalid {query}")
|
||||
|
||||
@@ -21,7 +21,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -40,7 +40,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -40,7 +40,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -65,7 +65,7 @@ game:
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -31,7 +31,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -29,7 +29,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent):
|
||||
)
|
||||
self.most_recent_action: Tuple[str, Dict]
|
||||
|
||||
def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return the agent's most recent action, formatted in CAOS format."""
|
||||
return self.most_recent_action
|
||||
|
||||
@@ -497,7 +497,6 @@ def game_and_agent():
|
||||
]
|
||||
|
||||
action_space = ActionManager(
|
||||
game=game,
|
||||
actions=actions, # ALL POSSIBLE ACTIONS
|
||||
nodes=[
|
||||
{
|
||||
|
||||
@@ -21,15 +21,15 @@ class TestPrimaiteSession:
|
||||
raise AssertionError
|
||||
|
||||
assert session is not None
|
||||
assert session.game.simulation
|
||||
assert len(session.game.agents) == 3
|
||||
assert len(session.game.rl_agents) == 1
|
||||
assert session.env.game.simulation
|
||||
assert len(session.env.game.agents) == 3
|
||||
assert len(session.env.game.rl_agents) == 1
|
||||
|
||||
assert session.policy
|
||||
assert session.env
|
||||
|
||||
assert session.game.simulation.network
|
||||
assert len(session.game.simulation.network.nodes) == 10
|
||||
assert session.env.game.simulation.network
|
||||
assert len(session.env.game.simulation.network.nodes) == 10
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_start_session(self, temp_primaite_session):
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from primaite.game.agent.rewards import WebpageUnavailablePenalty
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests.conftest import ControlledAgent
|
||||
|
||||
|
||||
@@ -35,3 +38,45 @@ def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == -0.7
|
||||
|
||||
|
||||
def test_uc2_rewards(game_and_agent):
|
||||
"""Test that the reward component correctly applies a penalty when the selected client cannot reach the database."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
|
||||
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
|
||||
server_1.software_manager.install(DatabaseService)
|
||||
db_service = server_1.software_manager.software.get("DatabaseService")
|
||||
db_service.start()
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
|
||||
db_client.run()
|
||||
|
||||
router: Router = game.simulation.network.get_node_by_hostname("router")
|
||||
router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2)
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
|
||||
db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state)
|
||||
assert reward_value == 1.0
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
|
||||
db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state)
|
||||
assert reward_value == -1.0
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ICSObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents import ProbabilisticAgent
|
||||
|
||||
|
||||
def test_probabilistic_agent():
|
||||
"""
|
||||
Check that the probabilistic agent selects actions with approximately the right probabilities.
|
||||
|
||||
Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator),
|
||||
we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values
|
||||
were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number
|
||||
generation.
|
||||
"""
|
||||
N_TRIALS = 10_000
|
||||
P_DO_NOTHING = 0.1
|
||||
P_NODE_APPLICATION_EXECUTE = 0.3
|
||||
P_NODE_FILE_DELETE = 0.6
|
||||
MIN_DO_NOTHING = 850
|
||||
MAX_DO_NOTHING = 1150
|
||||
MIN_NODE_APPLICATION_EXECUTE = 2800
|
||||
MAX_NODE_APPLICATION_EXECUTE = 3200
|
||||
MIN_NODE_FILE_DELETE = 5750
|
||||
MAX_NODE_FILE_DELETE = 6250
|
||||
|
||||
action_space = ActionManager(
|
||||
actions=[
|
||||
{"type": "DONOTHING"},
|
||||
{"type": "NODE_APPLICATION_EXECUTE"},
|
||||
{"type": "NODE_FILE_DELETE"},
|
||||
],
|
||||
nodes=[
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"applications": [{"application_name": "WebBrowser"}],
|
||||
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
|
||||
},
|
||||
],
|
||||
max_folders_per_node=2,
|
||||
max_files_per_folder=2,
|
||||
max_services_per_node=2,
|
||||
max_applications_per_node=2,
|
||||
max_nics_per_node=2,
|
||||
max_acl_rules=10,
|
||||
protocols=["TCP", "UDP", "ICMP"],
|
||||
ports=["HTTP", "DNS", "ARP"],
|
||||
act_map={
|
||||
0: {"action": "DONOTHING", "options": {}},
|
||||
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
|
||||
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
|
||||
},
|
||||
)
|
||||
observation_space = ObservationManager(ICSObservation())
|
||||
reward_function = RewardFunction()
|
||||
|
||||
pa = ProbabilisticAgent(
|
||||
agent_name="test_agent",
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
settings={
|
||||
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
|
||||
"random_seed": 120,
|
||||
},
|
||||
)
|
||||
|
||||
do_nothing_count = 0
|
||||
node_application_execute_count = 0
|
||||
node_file_delete_count = 0
|
||||
for _ in range(N_TRIALS):
|
||||
a = pa.get_action(0)
|
||||
if a == ("DONOTHING", {}):
|
||||
do_nothing_count += 1
|
||||
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
|
||||
node_application_execute_count += 1
|
||||
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
|
||||
node_file_delete_count += 1
|
||||
else:
|
||||
raise AssertionError("Probabilistic agent produced an unexpected action.")
|
||||
|
||||
assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING
|
||||
assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE
|
||||
assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE
|
||||
@@ -26,8 +26,8 @@ def test_create_dm_bot(dm_client):
|
||||
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot")
|
||||
|
||||
assert data_manipulation_bot.name == "DataManipulationBot"
|
||||
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
|
||||
assert data_manipulation_bot.protocol == IPProtocol.TCP
|
||||
assert data_manipulation_bot.port == Port.NONE
|
||||
assert data_manipulation_bot.protocol == IPProtocol.NONE
|
||||
assert data_manipulation_bot.payload == "DELETE"
|
||||
|
||||
|
||||
@@ -70,4 +70,13 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
|
||||
dm_bot._perform_data_manipulation(p_of_success=1.0)
|
||||
|
||||
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
|
||||
assert len(dm_bot.connections)
|
||||
assert len(dm_bot._host_db_client.connections)
|
||||
|
||||
|
||||
def test_dm_bot_fails_without_db_client(dm_client):
|
||||
dm_client.software_manager.uninstall("DatabaseClient")
|
||||
dm_bot = dm_client.software_manager.software.get("DataManipulationBot")
|
||||
assert dm_bot._host_db_client is None
|
||||
dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN
|
||||
dm_bot._perform_data_manipulation(p_of_success=1.0)
|
||||
assert dm_bot.attack_stage is DataManipulationAttackStage.FAILED
|
||||
|
||||
Reference in New Issue
Block a user