Merge branch 'dev' into bugfix/2455-notebook_updates

This commit is contained in:
Nick Todd
2024-04-16 11:03:48 +01:00
58 changed files with 1449 additions and 506 deletions

View File

@@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files
args: ['--maxkb=1000']
args: ['--maxkb=5000']
- id: mixed-line-ending
- id: requirements-txt-fixer
- repo: http://github.com/psf/black
@@ -28,3 +28,7 @@ repos:
additional_dependencies:
- flake8-docstrings
- flake8-annotations
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
- id: nbstripout

View File

@@ -82,7 +82,7 @@ Allows configuration of the chosen observation type. These are optional.
* ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, ACL rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space.
* ``nodes``: list of nodes that will be present in this agent's observation space. The ``node_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. Each node can also be configured with services, and files that should be monitored.
* ``links``: list of links that will be present in this agent's observation space. The ``link_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config.
* ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_address_order`` sets the encoding of ip addresses as integers within the observation space.
* ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_list`` sets the encoding of ip addresses as integers within the observation space.
For more information see :py:mod:`primaite.game.agent.observations`

View File

@@ -22,35 +22,35 @@ example firewall
network:
nodes:
- ref: firewall
hostname: firewall
type: firewall
start_up_duration: 0
shut_down_duration: 0
ports:
external_port: # port 1
ip_address: 192.168.20.1
subnet_mask: 255.255.255.0
internal_port: # port 2
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
dmz_port: # port 3
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
internal_inbound_acl:
hostname: firewall
type: firewall
start_up_duration: 0
shut_down_duration: 0
ports:
external_port: # port 1
ip_address: 192.168.20.1
subnet_mask: 255.255.255.0
internal_port: # port 2
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
dmz_port: # port 3
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
internal_inbound_acl:
...
internal_outbound_acl:
...
dmz_inbound_acl:
...
dmz_outbound_acl:
...
external_inbound_acl:
...
external_outbound_acl:
...
routes:
...
internal_outbound_acl:
...
dmz_inbound_acl:
...
dmz_outbound_acl:
...
external_inbound_acl:
...
external_outbound_acl:
...
routes:
...
.. include:: common/common_node_attributes.rst

View File

@@ -25,6 +25,7 @@ Contents
simulation_components/network/nodes/switch
simulation_components/network/nodes/wireless_router
simulation_components/network/nodes/firewall
simulation_components/network/switch
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/sys_log

View File

@@ -1 +1 @@
3.0.0b7
3.0.0b8

View File

@@ -1,15 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_agent_actions: true
save_step_metadata: false
@@ -490,6 +478,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
@@ -501,6 +491,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -512,6 +504,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -523,6 +517,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
@@ -534,6 +530,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
@@ -545,6 +543,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -703,23 +703,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
@@ -730,10 +722,12 @@ agents:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:

View File

@@ -492,6 +492,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
@@ -503,6 +505,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -514,6 +518,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -525,6 +531,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
@@ -536,6 +544,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
@@ -547,6 +557,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -704,23 +716,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
@@ -1284,23 +1288,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

View File

@@ -487,7 +487,9 @@ class RouterACLAddRuleAction(AbstractAction):
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
@@ -519,7 +521,7 @@ class RouterACLAddRuleAction(AbstractAction):
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
if source_port_id == 0:
return ["do_nothing"] # invalid formulation
elif source_port_id == 1:
@@ -528,13 +530,14 @@ class RouterACLAddRuleAction(AbstractAction):
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_ip_id == 0:
if dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
dst_ip = "ALL"
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
if dest_port_id == 0:
return ["do_nothing"] # invalid formulation
@@ -553,8 +556,10 @@ class RouterACLAddRuleAction(AbstractAction):
permission_str,
protocol,
str(src_ip),
src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
dst_port,
position,
]
@@ -624,7 +629,9 @@ class FirewallACLAddRuleAction(AbstractAction):
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
@@ -665,7 +672,7 @@ class FirewallACLAddRuleAction(AbstractAction):
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_ip_id == 0:
if dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
dst_ip = "ALL"
@@ -680,6 +687,8 @@ class FirewallACLAddRuleAction(AbstractAction):
else:
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
return [
"network",
@@ -692,8 +701,10 @@ class FirewallACLAddRuleAction(AbstractAction):
permission_str,
protocol,
str(src_ip),
src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
dst_port,
position,
]
@@ -871,7 +882,8 @@ class ActionManager:
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
ip_list: List[str] = [], # to allow us to map an index to an ip address.
wildcard_list: List[str] = [], # to allow mapping from wildcard index to
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
"""Init method for ActionManager.
@@ -897,8 +909,8 @@ class ActionManager:
:type protocols: List[str]
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
:type ports: List[str]
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_address_list: Optional[List[str]]
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_list: Optional[List[str]]
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
@@ -959,8 +971,10 @@ class ActionManager:
self.protocols: List[str] = protocols
self.ports: List[str] = ports
self.ip_address_list: List[str] = ip_address_list
self.ip_address_list: List[str] = ip_list
self.wildcard_list: List[str] = wildcard_list
if self.wildcard_list == []:
self.wildcard_list = ["NONE"]
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
"num_nodes": len(self.node_names),
@@ -1195,6 +1209,24 @@ class ActionManager:
raise RuntimeError(msg)
return self.ip_address_list[ip_idx]
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
"""
Get the IP wildcard corresponding to the given index.
:param ip_idx: The index of the IP wildcard to retrieve.
:type ip_idx: int
:return: The wildcard address.
:rtype: str
"""
if wildcard_idx >= len(self.wildcard_list):
msg = (
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.wildcard_list[wildcard_idx]
def get_port_by_idx(self, port_idx: int) -> str:
"""
Get the port corresponding to the given index.
@@ -1253,37 +1285,14 @@ class ActionManager:
:return: The constructed ActionManager.
:rtype: ActionManager
"""
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
# the nodes in the simulation.
# TODO: refactor. Options:
# 1: This should be pulled out into it's own function for clarity
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
# go through the nodes here.
ip_address_order = cfg["options"].pop("ip_address_order", {})
ip_address_list = []
for entry in ip_address_order:
node_name = entry["node_name"]
nic_num = entry["nic_num"]
node_obj = game.simulation.network.get_node_by_hostname(node_name)
ip_address = node_obj.network_interface[nic_num].ip_address
ip_address_list.append(ip_address)
if not ip_address_list:
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
for node_name in node_names:
node_obj = game.simulation.network.get_node_by_hostname(node_name)
if node_obj is None:
continue
network_interfaces = node_obj.network_interfaces
for nic_uuid, nic_obj in network_interfaces.items():
ip_address_list.append(nic_obj.ip_address)
if "ip_list" not in cfg["options"]:
cfg["options"]["ip_list"] = []
obj = cls(
actions=cfg["action_list"],
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=ip_address_list,
act_map=cfg.get("action_map"),
)

View File

@@ -43,7 +43,10 @@ class LinkObservation(AbstractObservation, identifier="LINK"):
"""
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
self.where[-1] = "<->".join(self.where[-1].split("<->")[::-1]) # try swapping endpoint A and B
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]

View File

@@ -189,7 +189,6 @@ class ObservationManager:
"""
if config is None:
return cls(NullObservation())
print(config)
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))

View File

@@ -26,7 +26,7 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING
from typing_extensions import Never
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Iterable[str | int] | None
class AbstractReward:
@@ -293,6 +294,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0

View File

@@ -1,8 +1,13 @@
from typing import Dict, Tuple
import random
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
class RandomAgent(AbstractScriptedAgent):
@@ -19,3 +24,60 @@ class RandomAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())
class PeriodicAgent(AbstractScriptedAgent):
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
class Settings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 20
"The timestep at which an agent begins performing it's actions."
start_variance: int = 5
"Deviation around the start step."
frequency: int = 5
"The number of timesteps to wait between performing actions."
variance: int = 0
"The amount the frequency can randomly change to."
max_executions: int = 999999
"Maximum number of times the agent can execute its action."
def __init__(
self,
agent_name: str,
action_space: ActionManager,
observation_space: ObservationManager,
reward_function: RewardFunction,
settings: Optional[Settings] = None,
) -> None:
"""Initialise PeriodicAgent."""
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.settings = settings or PeriodicAgent.Settings()
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
self.num_executions = 0
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep when the next execute action should be taken.
:type timestep: int
:param variance: Uniform random variance applied to the timestep
:type variance: int
"""
random_increment = random.randint(-variance, variance)
self.next_execution_timestep = timestep + random_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
return "DONOTHING", {}

View File

@@ -0,0 +1,78 @@
import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class TAP001(AbstractScriptedAgent):
"""
TAP001 | Mobile Malware -- Ransomware Variant.
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
next_execution_timestep: int = 0
starting_node_idx: int = 0
installed: bool = False
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute the ransomware application.
This application acts a wrapper around the kill-chain, similar to green-analyst and
the previous UC2 data manipulation bot.
:param obs: Current observation for this agent.
:type obs: ObsType
:param timestep: The current simulation timestep, used for scheduling actions
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
if not self.installed:
self.installed = True
return "NODE_APPLICATION_INSTALL", {
"node_id": self.starting_node_idx,
"application_name": "RansomwareScript",
"ip_address": self.ip_address,
}
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
for n, act in self.action_manager.action_map.items():
if not act[0] == "NODE_APPLICATION_INSTALL":
continue
if act[1]["node_id"] == self.starting_node_idx:
self.ip_address = act[1]["ip_address"]
return
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)

View File

@@ -11,7 +11,10 @@ from primaite.game.agent.observations.observation_manager import ObservationMana
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.airspace import AIR_SPACE
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -26,6 +29,7 @@ from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
@@ -43,6 +47,7 @@ APPLICATION_TYPES_MAPPING = {
"DatabaseClient": DatabaseClient,
"DataManipulationBot": DataManipulationBot,
"DoSBot": DoSBot,
"RansomwareScript": RansomwareScript,
}
"""List of available applications that can be installed on nodes in the PrimAITE Simulation."""
@@ -128,6 +133,8 @@ class PrimaiteGame:
"""
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
self.pre_timestep()
if self.step_counter == 0:
state = self.get_sim_state()
for agent in self.agents.values():
@@ -172,6 +179,10 @@ class PrimaiteGame:
response=response,
)
def pre_timestep(self) -> None:
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
self.simulation.pre_timestep(self.step_counter)
def advance_timestep(self) -> None:
"""Advance timestep."""
self.step_counter += 1
@@ -211,6 +222,7 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
AIR_SPACE.clear()
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
@@ -268,6 +280,9 @@ class PrimaiteGame:
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
else:
msg = f"invalid node type {n_type} in config"
@@ -339,6 +354,19 @@ class PrimaiteGame:
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "RansomwareScript":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
ransomware_encrypt_p_of_success=float(
opt.get("ransomware_encrypt_p_of_success", "0.1")
),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:
opt = application_cfg["options"]
@@ -383,6 +411,7 @@ class PrimaiteGame:
for link_cfg in links_cfg:
node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"])
node_b = net.get_node_by_hostname(link_cfg["endpoint_b_hostname"])
if isinstance(node_a, Switch):
endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
else:
@@ -423,6 +452,16 @@ class PrimaiteGame:
reward_function=reward_function,
settings=settings,
)
elif agent_type == "PeriodicAgent":
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
new_agent = PeriodicAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
settings=settings,
)
elif agent_type == "ProxyAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = ProxyAgent(
@@ -443,6 +482,15 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
elif agent_type == "TAP001":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = TAP001(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
else:
msg = f"Configuration error: {agent_type} is not a valid agent type."
_LOGGER.error(msg)

View File

@@ -26,6 +26,9 @@ class PrimaiteGymEnv(gymnasium.Env):
def __init__(self, game_config: Dict):
"""Initialise the environment."""
super().__init__()
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
@@ -49,6 +52,7 @@ class PrimaiteGymEnv(gymnasium.Env):
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.pre_timestep()
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
@@ -224,6 +228,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.pre_timestep()
self.game.apply_agent_actions()
# 2. Advance timestep

View File

@@ -29,10 +29,12 @@ class PrimaiteIO:
"""Whether to save a log of all agents' actions every step."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
save_pcap_logs: bool = False
save_pcap_logs: bool = True
"""Whether to save PCAP logs."""
save_sys_logs: bool = False
save_sys_logs: bool = True
"""Whether to save system logs."""
write_sys_log_to_terminal: bool = False
"""Whether to write the sys log to the terminal."""
def __init__(self, settings: Optional[Settings] = None) -> None:
"""
@@ -47,6 +49,7 @@ class PrimaiteIO:
SIM_OUTPUT.path = self.session_path / "simulation_output"
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
@@ -93,4 +96,5 @@ class PrimaiteIO:
def from_config(cls, config: Dict) -> "PrimaiteIO":
"""Create an instance of PrimaiteIO based on a configuration dict."""
new = cls(settings=cls.Settings(**config))
return new

View File

@@ -14,6 +14,7 @@ class _SimOutput:
)
self.save_pcap_logs: bool = False
self.save_sys_logs: bool = False
self.write_sys_log_to_terminal: bool = False
@property
def path(self) -> Path:

View File

@@ -258,8 +258,7 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
}
},
"nbformat": 4,
"nbformat_minor": 2

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "03b2013a-b7d1-47ee-b08c-8dab83833720",
"id": "0",
"metadata": {},
"source": [
"# PrimAITE Router Simulation Demo\n",
@@ -12,7 +12,7 @@
},
{
"cell_type": "raw",
"id": "c8bb5698-e746-4e90-9c2f-efe962acdfa0",
"id": "1",
"metadata": {},
"source": [
" +------------+\n",
@@ -48,7 +48,7 @@
},
{
"cell_type": "markdown",
"id": "415d487c-6457-497d-85d6-99439b3541e7",
"id": "2",
"metadata": {},
"source": [
"## The Network\n",
@@ -60,7 +60,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "de57ac8c-5b28-4847-a759-2ceaf5593329",
"id": "3",
"metadata": {
"tags": []
},
@@ -72,7 +72,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a1e2e4df-67c0-4584-ab27-47e2c7c7fcd2",
"id": "4",
"metadata": {
"tags": []
},
@@ -83,7 +83,7 @@
},
{
"cell_type": "markdown",
"id": "fb052c56-e9ca-4093-9115-d0c440b5ff53",
"id": "5",
"metadata": {},
"source": [
"Most of the Network components have a `.show()` function that prints a table of information about that object. We can view the Nodes and Links on the Network by calling `network.show()`."
@@ -92,7 +92,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cc199741-ef2e-47f5-b2f0-e20049ccf40f",
"id": "6",
"metadata": {
"tags": []
},
@@ -103,7 +103,7 @@
},
{
"cell_type": "markdown",
"id": "76d2b7e9-280b-4741-a8b3-a84bed219fac",
"id": "7",
"metadata": {
"tags": []
},
@@ -115,7 +115,7 @@
},
{
"cell_type": "markdown",
"id": "84113002-843e-4cab-b899-667b50f25f6b",
"id": "8",
"metadata": {},
"source": [
"### Router Nodes\n",
@@ -125,7 +125,7 @@
},
{
"cell_type": "markdown",
"id": "bf63a178-eee5-4669-bf64-13aea7ecf6cb",
"id": "9",
"metadata": {},
"source": [
"Calling `router.show()` displays the Ethernet interfaces on the Router. If you need a table in markdown format, pass `markdown=True`."
@@ -134,7 +134,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e76d1854-961e-438c-b40f-77fd9c3abe38",
"id": "10",
"metadata": {
"tags": []
},
@@ -145,7 +145,7 @@
},
{
"cell_type": "markdown",
"id": "e000540c-687c-4254-870c-1d814603bdbf",
"id": "11",
"metadata": {},
"source": [
"Calling `router.arp.show()` displays the Router ARP Cache."
@@ -154,7 +154,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "92de8b42-92d7-4934-9c12-50bf724c9eb2",
"id": "12",
"metadata": {
"tags": []
},
@@ -165,7 +165,7 @@
},
{
"cell_type": "markdown",
"id": "a9ff7ee8-9482-44de-9039-b684866bdc82",
"id": "13",
"metadata": {},
"source": [
"Calling `router.acl.show()` displays the Access Control List."
@@ -174,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5922282a-d22b-4e55-9176-f3f3654c849f",
"id": "14",
"metadata": {
"tags": []
},
@@ -185,7 +185,7 @@
},
{
"cell_type": "markdown",
"id": "71c87884-f793-4c9f-b004-5b0df86cf585",
"id": "15",
"metadata": {},
"source": [
"Calling `router.router_table.show()` displays the static routes the Router provides."
@@ -194,7 +194,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "327203be-f475-4727-82a1-e992d3b70ed8",
"id": "16",
"metadata": {
"tags": []
},
@@ -205,7 +205,7 @@
},
{
"cell_type": "markdown",
"id": "eef561a8-3d39-4c8b-bbc8-e8b10b8ed25f",
"id": "17",
"metadata": {},
"source": [
"Calling `router.sys_log.show()` displays the Router system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
@@ -214,7 +214,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3d0aa004-b10c-445f-aaab-340e0e716c74",
"id": "18",
"metadata": {
"tags": []
},
@@ -225,7 +225,7 @@
},
{
"cell_type": "markdown",
"id": "25630c90-c54e-4b5d-8bf4-ad1b0722e126",
"id": "19",
"metadata": {},
"source": [
"### Switch Nodes\n",
@@ -235,7 +235,7 @@
},
{
"cell_type": "markdown",
"id": "4879394d-2981-40de-a229-e19b09a34e6e",
"id": "20",
"metadata": {},
"source": [
"Calling `switch.show()` displays the Switch orts on the Switch."
@@ -244,7 +244,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e7fd439b-5442-4e9d-9e7d-86dacb77f458",
"id": "21",
"metadata": {
"tags": []
},
@@ -255,7 +255,7 @@
},
{
"cell_type": "markdown",
"id": "beb8dbd6-7250-4ac9-9fa2-d2a9c0e5fd19",
"id": "22",
"metadata": {
"tags": []
},
@@ -266,7 +266,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d06e1310-4a77-4315-a59f-cb1b49ca2352",
"id": "23",
"metadata": {
"tags": []
},
@@ -277,7 +277,7 @@
},
{
"cell_type": "markdown",
"id": "fda75ac3-8123-4234-8f36-86547891d8df",
"id": "24",
"metadata": {},
"source": [
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
@@ -286,7 +286,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a0d984b7-a7c1-4bbd-aa5a-9d3caecb08dc",
"id": "25",
"metadata": {
"tags": []
},
@@ -297,7 +297,7 @@
},
{
"cell_type": "markdown",
"id": "2f1d99ad-db4f-4baf-8a35-e1d95f269586",
"id": "26",
"metadata": {},
"source": [
"### Computer/Server Nodes\n",
@@ -307,7 +307,7 @@
},
{
"cell_type": "markdown",
"id": "c9e2251a-1b47-46e5-840f-7fec3e39c5aa",
"id": "27",
"metadata": {
"tags": []
},
@@ -318,7 +318,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "656c37f6-b145-42af-9714-8d2886d0eff8",
"id": "28",
"metadata": {
"tags": []
},
@@ -329,7 +329,7 @@
},
{
"cell_type": "markdown",
"id": "f1097a49-a3da-4d79-a06d-ae8af452918f",
"id": "29",
"metadata": {},
"source": [
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
@@ -338,7 +338,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "66b267d6-2308-486a-b9aa-cb8d3bcf0753",
"id": "30",
"metadata": {
"tags": []
},
@@ -349,7 +349,7 @@
},
{
"cell_type": "markdown",
"id": "0d1fcad8-5b1a-4d8b-a49f-aa54a95fcaf0",
"id": "31",
"metadata": {},
"source": [
"Calling `switch.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
@@ -358,7 +358,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1b5debe8-ef1b-445d-8fa9-6a45568f21f3",
"id": "32",
"metadata": {
"tags": []
},
@@ -369,7 +369,7 @@
},
{
"cell_type": "markdown",
"id": "fcfa1773-798c-4ada-9318-c3ad928217da",
"id": "33",
"metadata": {},
"source": [
"## Basic Network Comms Check\n",
@@ -380,7 +380,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "495b7de4-b6ce-41a6-9114-f74752ab4491",
"id": "34",
"metadata": {
"tags": []
},
@@ -391,7 +391,7 @@
},
{
"cell_type": "markdown",
"id": "3e13922a-217f-4f4e-99b6-57a07613cade",
"id": "35",
"metadata": {},
"source": [
"We'll first ping client_1's default gateway."
@@ -400,7 +400,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a38abb71-994e-49e8-8f51-e9a550e95b99",
"id": "36",
"metadata": {
"tags": []
},
@@ -412,7 +412,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8388e1e9-30e3-4534-8e5a-c6e9144149d2",
"id": "37",
"metadata": {
"tags": []
},
@@ -423,7 +423,7 @@
},
{
"cell_type": "markdown",
"id": "02c76d5c-d954-49db-912d-cb9c52f46375",
"id": "38",
"metadata": {},
"source": [
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
@@ -432,7 +432,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ff8e976a-c16b-470c-8923-325713a30d6c",
"id": "39",
"metadata": {
"tags": []
},
@@ -443,7 +443,7 @@
},
{
"cell_type": "markdown",
"id": "80280404-a5ab-452f-8a02-771a0d7496b1",
"id": "40",
"metadata": {},
"source": [
"And finally, we'll ping the web server."
@@ -452,7 +452,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c4163f8d-6a72-410c-9f5c-4f881b7de45e",
"id": "41",
"metadata": {
"tags": []
},
@@ -463,7 +463,7 @@
},
{
"cell_type": "markdown",
"id": "1194c045-ba77-4427-be30-ed7b5b224850",
"id": "42",
"metadata": {},
"source": [
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
@@ -472,7 +472,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e79a523a-5780-45b6-8798-c434e0e522bd",
"id": "43",
"metadata": {
"tags": []
},
@@ -483,7 +483,7 @@
},
{
"cell_type": "markdown",
"id": "5928f6dd-1006-45e3-99f3-8f311a875faa",
"id": "44",
"metadata": {},
"source": [
"## Advanced Network Usage\n",
@@ -493,7 +493,7 @@
},
{
"cell_type": "markdown",
"id": "5e023ef3-7d18-4006-96ee-042a06a481fc",
"id": "45",
"metadata": {},
"source": [
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
@@ -502,7 +502,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "603cf913-e261-49da-a7dd-85e1bb6dec56",
"id": "46",
"metadata": {
"tags": []
},
@@ -513,7 +513,7 @@
},
{
"cell_type": "markdown",
"id": "5cf962a4-20e6-44ae-9748-7fc5267ae111",
"id": "47",
"metadata": {},
"source": [
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
@@ -522,7 +522,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e047de00-3de4-4823-b26a-2c8d64c7a663",
"id": "48",
"metadata": {
"tags": []
},
@@ -533,7 +533,7 @@
},
{
"cell_type": "markdown",
"id": "bdc4741d-6e3e-4aec-a69c-c2e9653bd02c",
"id": "49",
"metadata": {},
"source": [
"Now we'll add an ACL to block ICMP from 192.168.10.22"
@@ -542,7 +542,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6db355ae-b99a-441b-a2c4-4ffe78f46bff",
"id": "50",
"metadata": {
"tags": []
},
@@ -562,7 +562,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a345e000-8842-4827-af96-adc0fbe390fb",
"id": "51",
"metadata": {
"tags": []
},
@@ -573,7 +573,7 @@
},
{
"cell_type": "markdown",
"id": "3a5bfd9f-04cb-493e-a86c-cd268563a262",
"id": "52",
"metadata": {},
"source": [
"Now we attempt (and fail) to ping the web server"
@@ -582,7 +582,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a4f4ff31-590f-40fb-b13d-efaa8c2720b6",
"id": "53",
"metadata": {
"tags": []
},
@@ -593,7 +593,7 @@
},
{
"cell_type": "markdown",
"id": "83e56497-097b-45cb-964e-b15c72547b38",
"id": "54",
"metadata": {},
"source": [
"We can check that the ping was actually sent by client_2 by viewing the sys log"
@@ -602,7 +602,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f62b8a4e-fd3b-4059-b108-3d4a0b18f2a0",
"id": "55",
"metadata": {
"tags": []
},
@@ -613,7 +613,7 @@
},
{
"cell_type": "markdown",
"id": "c7040311-a879-4620-86a0-55d0774156e5",
"id": "56",
"metadata": {},
"source": [
"We can check the router sys log to see why the traffic was blocked"
@@ -622,7 +622,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "7e53d776-99da-4d2c-a2a7-bd7ce27bff4c",
"id": "57",
"metadata": {
"tags": []
},
@@ -633,7 +633,7 @@
},
{
"cell_type": "markdown",
"id": "aba0bc7d-da57-477b-b34a-3688b5aab2c6",
"id": "58",
"metadata": {},
"source": [
"Now a final check to ensure that client_1 can still ping the web_server."
@@ -642,7 +642,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d542734b-7582-4af7-8254-bda3de50d091",
"id": "59",
"metadata": {
"tags": []
},
@@ -654,7 +654,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d78e9fe3-02c6-4792-944f-5622e26e0412",
"id": "60",
"metadata": {
"tags": []
},

View File

@@ -226,6 +226,15 @@ class SimComponent(BaseModel):
return
return self._request_manager(request, context)
def pre_timestep(self, timestep: int) -> None:
"""
Apply any logic that needs to happen at the beginning of the timestep to ensure correct observations/rewards.
:param timestep: what's the current time
:type timestep: int
"""
pass
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep evolution to this component.

View File

@@ -103,6 +103,10 @@ class File(FileSystemItemABC):
"""
super().apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
# reset the number of accesses to 0
self.num_access = 0

View File

@@ -427,15 +427,21 @@ class FileSystem(SimComponent):
"""Apply time step to FileSystem and its child folders and files."""
super().apply_timestep(timestep=timestep)
# apply timestep to folders
for folder_id in self.folders:
self.folders[folder_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
# reset number of file creations
self.num_file_creations = 0
# reset number of file deletions
self.num_file_deletions = 0
# apply timestep to folders
for folder_id in self.folders:
self.folders[folder_id].apply_timestep(timestep=timestep)
for folder in self.folders.values():
folder.pre_timestep(timestep)
###############################################################
# Agent actions

View File

@@ -128,6 +128,13 @@ class Folder(FileSystemItemABC):
for file_id in self.files:
self.files[file_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for file in self.files.values():
file.pre_timestep(timestep)
def _scan_timestep(self) -> None:
"""Apply the scan action timestep."""
if self.scan_countdown >= 0:

View File

@@ -1,3 +1,4 @@
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional
import matplotlib.pyplot as plt
@@ -86,6 +87,16 @@ class Network(SimComponent):
for link_id in self.links:
self.links[link_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for node in self.nodes.values():
node.pre_timestep(timestep)
for link in self.links.values():
link.pre_timestep(timestep)
@property
def router_nodes(self) -> List[Node]:
"""The Routers in the Network."""
@@ -163,10 +174,11 @@ class Network(SimComponent):
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
print(table)
if links:

View File

@@ -9,7 +9,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int:
"""
Calculate the minimum number of network switches required to connect a given number of nodes.
@@ -18,7 +18,7 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
to accommodate all nodes under this constraint.
:param num_nodes: The total number of nodes that need to be connected in the network.
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
:param max_network_interface: The maximum number of ports available on each switch. Defaults to 24.
:return: The minimum number of switches required to connect all PCs.
@@ -33,11 +33,11 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
3
"""
# Reduce the effective number of switch ports by 1 to leave space for the router
effective_switch_ports = max_switch_ports - 1
effective_network_interface = max_network_interface - 1
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
full_switches = num_nodes // effective_switch_ports
extra_pcs = num_nodes % effective_switch_ports
full_switches = num_nodes // effective_network_interface
extra_pcs = num_nodes % effective_network_interface
# Return the total number of switches required
return full_switches + (1 if extra_pcs > 0 else 0)
@@ -77,7 +77,7 @@ def create_office_lan(
# Calculate the required number of switches
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
effective_switch_ports = 23 # One port less for router connection
effective_network_interface = 23 # One port less for router connection
if pcs_ip_block_start <= num_of_switches:
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
@@ -116,7 +116,7 @@ def create_office_lan(
# Add PCs to the LAN and connect them to switches
for i in range(1, num_pcs + 1):
# Add a new edge switch if the current one is full
if switch_port == effective_switch_ports:
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)

View File

@@ -264,6 +264,9 @@ class NetworkInterface(SimComponent, ABC):
"""
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}"
def __hash__(self) -> int:
return hash(self.uuid)
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep evolution to this component.
@@ -661,6 +664,10 @@ class Link(SimComponent):
def apply_timestep(self, timestep: int) -> None:
"""Apply a timestep to the simulation."""
super().apply_timestep(timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.current_load = 0.0
@@ -895,6 +902,10 @@ class Node(SimComponent):
from primaite.simulator.system.applications.web_browser import WebBrowser
return WebBrowser
elif application_class_str == "RansomwareScript":
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
return RansomwareScript
else:
return 0
@@ -965,12 +976,15 @@ class Node(SimComponent):
table.align = "l"
table.title = f"{self.hostname} Network Interface Cards"
for port, network_interface in self.network_interface.items():
ip_address = ""
if hasattr(network_interface, "ip_address"):
ip_address = f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}"
table.add_row(
[
port,
network_interface.__class__.__name__,
network_interface.mac_address,
f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}",
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
]
@@ -1071,6 +1085,23 @@ class Node(SimComponent):
self.file_system.apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for network_interface in self.network_interfaces.values():
network_interface.pre_timestep(timestep=timestep)
for process_id in self.processes:
self.processes[process_id].pre_timestep(timestep=timestep)
for service_id in self.services:
self.services[service_id].pre_timestep(timestep=timestep)
for application_id in self.applications:
self.applications[application_id].pre_timestep(timestep=timestep)
self.file_system.pre_timestep(timestep=timestep)
def scan(self) -> bool:
"""
Scan the node and all the items within it.
@@ -1341,6 +1372,8 @@ class Node(SimComponent):
application_instance.configure(target_ip_address=IPv4Address(ip_address))
elif application_instance.name == "DataManipulationBot":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
elif application_instance.name == "RansomwareScript":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
else:
pass

View File

@@ -599,7 +599,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -612,7 +614,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -625,7 +629,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -638,7 +644,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -651,7 +659,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -664,7 +674,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)

View File

@@ -322,10 +322,12 @@ class AccessControlList(SimComponent):
action=ACLAction[request[0]],
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
src_port=None if request[3] == "ALL" else Port[request[3]],
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
dst_port=None if request[5] == "ALL" else Port[request[5]],
position=int(request[6]),
src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]),
src_port=None if request[4] == "ALL" else Port[request[4]],
dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]),
dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]),
dst_port=None if request[7] == "ALL" else Port[request[7]],
position=int(request[8]),
)
)
),
@@ -772,6 +774,13 @@ class RouterARP(ARP):
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt,
)
elif route and route == self.router.route_table.default_route:
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
return self._get_arp_cache_mac_address(
ip_address=self.router.route_table.default_route.next_hop_ip_address,
is_reattempt=True,
is_default_route_attempt=True,
)
else:
if self.router.route_table.default_route:
if not is_default_route_attempt:
@@ -822,6 +831,12 @@ class RouterARP(ARP):
return network_interface
if not is_reattempt:
if self.router.ip_is_in_router_interface_subnet(ip_address):
self.send_arp_request(ip_address)
return self._get_arp_cache_network_interface(
ip_address=ip_address, is_reattempt=True, is_default_route_attempt=is_default_route_attempt
)
route = self.router.route_table.find_best_route(ip_address)
if route and route != self.router.route_table.default_route:
self.send_arp_request(route.next_hop_ip_address)
@@ -830,6 +845,13 @@ class RouterARP(ARP):
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt,
)
elif route and route == self.router.route_table.default_route:
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
return self._get_arp_cache_network_interface(
ip_address=self.router.route_table.default_route.next_hop_ip_address,
is_reattempt=True,
is_default_route_attempt=True,
)
else:
if self.router.route_table.default_route:
if not is_default_route_attempt:
@@ -1460,6 +1482,8 @@ class Router(NetworkNode):
frame.ethernet.src_mac_addr = network_interface.mac_address
frame.ethernet.dst_mac_addr = target_mac
network_interface.send_frame(frame)
else:
self.sys_log.error(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
"""
@@ -1540,6 +1564,13 @@ class Router(NetworkNode):
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
- routes (list[dict]): List of route dicts with values:
- address (str): The destination address of the route.
- subnet_mask (str): The subnet mask of the route.
- next_hop_ip_address (str): The next hop IP for the route.
- metric (int): The metric of the route. Optional.
- default_route:
- next_hop_ip_address (str): The next hop IP for the route.
Example config:
```
@@ -1550,6 +1581,10 @@ class Router(NetworkNode):
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
},
2: {
'ip_address' : '192.168.0.1',
'subnet_mask' : '255.255.255.252',
}
},
'acl' : {
@@ -1557,6 +1592,10 @@ class Router(NetworkNode):
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
'routes' : [
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
],
'default_route': {'next_hop_ip_address': '192.168.0.2'}
}
```
@@ -1600,4 +1639,8 @@ class Router(NetworkNode):
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return router

View File

@@ -100,13 +100,8 @@ class Switch(NetworkNode):
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.network_interface:
self.network_interface = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.network_interface.items():
port._connected_node = self
port.port_num = port_num
port.parent = self
port.port_num = port_num
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def show(self, markdown: bool = False):
"""

View File

@@ -8,7 +8,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket
from primaite.simulator.network.protocols.packet import DataPacket
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader
from primaite.simulator.network.utils import convert_bytes_to_megabits
_LOGGER = getLogger(__name__)
@@ -141,3 +141,37 @@ class Frame(BaseModel):
def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
"""The daa transfer size of the Frame in Mbits."""
return convert_bytes_to_megabits(self.size)
@property
def is_broadcast(self) -> bool:
"""
Determines if the Frame is a broadcast frame.
A Frame is considered a broadcast frame if the destination MAC address is set to the broadcast address
"ff:ff:ff:ff:ff:ff".
:return: True if the destination MAC address is a broadcast address, otherwise False.
"""
return self.ethernet.dst_mac_addr.lower() == "ff:ff:ff:ff:ff:ff"
@property
def is_arp(self) -> bool:
"""
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
:return: True if the Frame is an ARP packet, otherwise False.
"""
return self.udp.dst_port == Port.ARP
@property
def is_icmp(self) -> bool:
"""
Determines if the Frame is an ICMP (Internet Control Message Protocol) packet.
This check is performed by verifying if the 'icmp' attribute of the Frame instance is present (not None).
:return: True if the Frame is an ICMP packet (i.e., has an ICMP header), otherwise False.
"""
return self.icmp is not None

View File

@@ -11,6 +11,9 @@ class Port(Enum):
.. _List of Ports:
"""
UNUSED = -1
"An unused port stub."
NONE = 0
"Place holder for a non-port."
WOL = 9

View File

@@ -63,3 +63,8 @@ class Simulation(SimComponent):
"""Apply a timestep to the simulation."""
super().apply_timestep(timestep)
self.network.apply_timestep(timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.network.pre_timestep(timestep)

View File

@@ -80,7 +80,10 @@ class Application(IOSoftware):
"""
super().apply_timestep(timestep=timestep)
self.num_executions = 0 # reset number of executions
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.num_executions = 0
def _can_perform_action(self) -> bool:
"""

View File

@@ -31,6 +31,7 @@ class DatabaseClient(Application):
"""Keep track of connections that were established or verified during this step. Used for rewards."""
last_query_response: Optional[Dict] = None
"""Keep track of the latest query response. Used to determine rewards."""
_server_connection_id: Optional[str] = None
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseClient"
@@ -51,10 +52,9 @@ class DatabaseClient(Application):
def execute(self) -> bool:
"""Execution definition for db client: perform a select query."""
self.num_executions += 1 # trying to connect counts as an execution
if self.connections:
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
else:
can_connect = self.check_connection(connection_id=str(uuid4()))
if not self._server_connection_id:
self.connect()
can_connect = self.check_connection(connection_id=self._server_connection_id)
self._last_connection_successful = can_connect
return can_connect
@@ -80,17 +80,21 @@ class DatabaseClient(Application):
self.server_password = server_password
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self, connection_id: Optional[str] = None) -> bool:
def connect(self) -> bool:
"""Connect to a Database Service."""
if not self._can_perform_action():
return False
if not connection_id:
connection_id = str(uuid4())
if not self._server_connection_id:
self._server_connection_id = str(uuid4())
self.connected = self._connect(
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
server_ip_address=self.server_ip_address,
password=self.server_password,
connection_id=self._server_connection_id,
)
if not self.connected:
self._server_connection_id = None
return self.connected
def check_connection(self, connection_id: str) -> bool:
@@ -125,7 +129,7 @@ class DatabaseClient(Application):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
if self.connections.get(connection_id):
if self._server_connection_id:
self.sys_log.info(
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
)
@@ -149,31 +153,28 @@ class DatabaseClient(Application):
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
)
def disconnect(self, connection_id: Optional[str] = None) -> bool:
def disconnect(self) -> bool:
"""Disconnect from the Database Service."""
if not self._can_perform_action():
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
return False
# if there are no connections - nothing to disconnect
if not len(self.connections):
if not self._server_connection_id:
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
return False
# if no connection provided, disconnect the first connection
if not connection_id:
connection_id = list(self.connections.keys())[0]
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_id},
payload={"type": "disconnect", "connection_id": self._server_connection_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
self.remove_connection(connection_id=connection_id)
self.remove_connection(connection_id=self._server_connection_id)
self.sys_log.info(
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
f"{self.name}: DatabaseClient disconnected {self._server_connection_id} from {self.server_ip_address}"
)
self.connected = False
@@ -224,18 +225,20 @@ class DatabaseClient(Application):
# reset last query response
self.last_query_response = None
if connection_id is None:
if self.connections:
connection_id = list(self.connections.keys())[-1]
# TODO: if the most recent connection dies, it should be automatically cleared.
else:
connection_id = str(uuid4())
connection_id: str
if not self.connections.get(connection_id):
if not self.connect(connection_id=connection_id):
return False
if not connection_id:
connection_id = self._server_connection_id
if not connection_id:
self.connect()
connection_id = self._server_connection_id
if not connection_id:
msg = "Cannot run sql query, could not establish connection with the server."
self.parent.sys_log(msg)
return False
# Initialise the tracker of this ID to False
uuid = str(uuid4())
self._query_success_tracker[uuid] = False
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)

View File

@@ -0,0 +1,316 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
class RansomwareAttackStage(IntEnum):
"""
Enumeration representing different attack stages of the ransomware script.
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
in the simulation.
Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
DOWNLOAD = 1
"Installing the Encryption Script - Testing"
INSTALL = 2
"The stage where logon procedures are simulated."
ACTIVATE = 3
"Operating Status Changes"
PROPAGATE = 4
"Represents the stage of performing a horizontal port scan on the target."
COMMAND_AND_CONTROL = 5
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
PAYLOAD = 6
"Stage of actively attacking the target."
SUCCEEDED = 7
"Indicates the attack has been successfully completed."
FAILED = 8
"Signifies that the attack has failed."
class RansomwareScript(Application):
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
:ivar payload: The attack stage query payload. (Default Corrupt)
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:ivar repeat: Whether to repeat attacking once finished.
"""
server_ip_address: Optional[IPv4Address] = None
"""IP address of node which hosts the database."""
server_password: Optional[str] = None
"""Password required to access the database."""
payload: Optional[str] = "ENCRYPT"
"Payload String for the payload stage"
target_scan_p_of_success: float = 0.9
"Probability of the target scan succeeding: Default 0.9"
c2_beacon_p_of_success: float = 0.9
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
ransomware_encrypt_p_of_success: float = 0.9
"Probability of the ransomware attack succeeding: Default 0.9"
repeat: bool = False
"If true, the Denial of Service bot will keep performing the attack."
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
"The ransomware attack stage. See RansomwareAttackStage Class"
def __init__(self, **kwargs):
kwargs["name"] = "RansomwareScript"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
return state
@property
def _host_db_client(self) -> DatabaseClient:
"""Return the database client that is installed on the same machine as the Ransomware Script."""
db_client = self.software_manager.software.get("DatabaseClient")
if db_client is None:
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
return db_client
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
name="execute",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
)
return rm
def _activate(self):
"""
Simulate the install process as the initial stage of the attack.
Advances the attack stage to 'ACTIVATE' attack state.
"""
if self.attack_stage == RansomwareAttackStage.INSTALL:
self.sys_log.info(f"{self.name}: Activated!")
self.attack_stage = RansomwareAttackStage.ACTIVATE
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, triggering the application loop.
:param timestep: The timestep value to update the bot's state.
"""
pass
def run(self) -> bool:
"""Calls the parent classes execute method before starting the application loop."""
super().run()
return True
def _application_loop(self) -> bool:
"""
The main application loop of the script, handling the attack process.
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if not self._can_perform_action():
return False
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self.attack_stage = RansomwareAttackStage.NOT_STARTED
self._local_download()
self._install()
self._activate()
self._perform_target_scan()
self._setup_beacon()
self._perform_ransomware_encrypt()
if self.repeat and self.attack_stage in (
RansomwareAttackStage.SUCCEEDED,
RansomwareAttackStage.FAILED,
):
self.attack_stage = RansomwareAttackStage.NOT_STARTED
return True
else:
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
return False
def configure(
self,
server_ip_address: IPv4Address,
server_password: Optional[str] = None,
payload: Optional[str] = None,
target_scan_p_of_success: Optional[float] = None,
c2_beacon_p_of_success: Optional[float] = None,
ransomware_encrypt_p_of_success: Optional[float] = None,
repeat: bool = True,
):
"""
Configure the Ransomware Script to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The attack stage query (Encrypt / Delete)
:param target_scan_p_of_success: The probability of success for the target scan stage.
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:param repeat: Whether to repeat attacking once finished.
"""
if server_ip_address:
self.server_ip_address = server_ip_address
if server_password:
self.server_password = server_password
if payload:
self.payload = payload
if target_scan_p_of_success:
self.target_scan_p_of_success = target_scan_p_of_success
if c2_beacon_p_of_success:
self.c2_beacon_p_of_success = c2_beacon_p_of_success
if ransomware_encrypt_p_of_success:
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
if repeat:
self.repeat = repeat
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
f"{repeat=}."
)
def _install(self):
"""
Simulate the install stage in the kill-chain.
Advances the attack stage to 'ACTIVATE' if successful.
From this attack stage onwards.
the ransomware application is now visible from this point onwardin the observation space.
"""
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
downloads_folder = self.file_system.get_folder(folder_name="downloads")
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
ransomware_file.num_access += 1
self.attack_stage = RansomwareAttackStage.INSTALL
def _setup_beacon(self):
"""
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
by default the success rate is 0.5
"""
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
if simulate_trial(self.c2_beacon_p_of_success):
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
if c2c_setup:
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
def _perform_target_scan(self):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PROPAGATE` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
# perform a port scan to identify that the SQL port is open on the server
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
if simulate_trial(self.target_scan_p_of_success):
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
if port_is_open:
self.attack_stage = RansomwareAttackStage.PROPAGATE
def attack(self) -> bool:
"""Perform the attack steps after opening the application."""
if not self._can_perform_action():
_LOGGER.debug("Ransomware application is unable to perform it's actions.")
self.run()
self.num_executions += 1
return self._application_loop()
def _perform_ransomware_encrypt(self):
"""
Execute the Ransomware Encrypt payload on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
"""
if self._host_db_client is None:
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
self.attack_stage = RansomwareAttackStage.FAILED
return
self._host_db_client.server_ip_address = self.server_ip_address
self._host_db_client.server_password = self.server_password
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
if simulate_trial(self.ransomware_encrypt_p_of_success):
self.sys_log.info(f"{self.name}: Attempting to launch payload")
if not len(self._host_db_client.connections):
self._host_db_client.connect()
if len(self._host_db_client.connections):
self._host_db_client.query(self.payload)
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
attack_successful = True
if attack_successful:
self.sys_log.info(f"{self.name}: Payload Successful")
self.attack_stage = RansomwareAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Payload failed")
self.attack_stage = RansomwareAttackStage.FAILED
else:
self.sys_log.error("Attack Attempted to launch too quickly")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download(self):
"""Downloads itself via the onto the local file_system."""
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
if self._local_download_verify():
self.attack_stage = RansomwareAttackStage.DOWNLOAD
else:
self.sys_log.info("Malware failed to create a installation location")
self.attack_stage = RansomwareAttackStage.FAILED
else:
self.sys_log.info("Malware failed to download")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download_verify(self) -> bool:
"""Verifies a download location - Creates one if needed."""
for folder in self.file_system.folders:
if self.file_system.folders[folder].name == "downloads":
self.file_system.num_file_creations += 1
return True
self.file_system.create_folder("downloads")
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
return True

View File

@@ -88,6 +88,10 @@ class SysLog:
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
if to_terminal or SIM_OUTPUT.write_sys_log_to_terminal:
print(f"{self.hostname}: ({level}) {msg}")
def debug(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the DEBUG level.
@@ -97,8 +101,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.debug(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "DEBUG", to_terminal)
def info(self, msg: str, to_terminal: bool = False):
"""
@@ -109,8 +112,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.info(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "INFO", to_terminal)
def warning(self, msg: str, to_terminal: bool = False):
"""
@@ -121,8 +123,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.warning(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "WARNING", to_terminal)
def error(self, msg: str, to_terminal: bool = False):
"""
@@ -133,8 +134,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.error(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "ERROR", to_terminal)
def critical(self, msg: str, to_terminal: bool = False):
"""
@@ -145,5 +145,4 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.critical(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "CRITICAL", to_terminal)

View File

@@ -141,8 +141,7 @@ class DatabaseService(Service):
"""Returns the database file."""
return self.file_system.get_file(folder_name="database", file_name="database.db")
@property
def folder(self) -> Folder:
def _return_database_folder(self) -> Folder:
"""Returns the database folder."""
return self.file_system.get_folder_by_id(self.db_file.folder_id)
@@ -187,7 +186,10 @@ class DatabaseService(Service):
}
def _process_sql(
self, query: Literal["SELECT", "DELETE", "INSERT"], query_id: str, connection_id: Optional[str] = None
self,
query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"],
query_id: str,
connection_id: Optional[str] = None,
) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
@@ -196,6 +198,7 @@ class DatabaseService(Service):
- SELECT : returns the data
- DELETE : deletes the data
- INSERT : inserts the data
- ENCRYPT : corrupts the data
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
@@ -207,7 +210,15 @@ class DatabaseService(Service):
return {"status_code": 404, "type": "sql", "data": False}
if query == "SELECT":
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
return {
"status_code": 200,
"type": "sql",
@@ -226,6 +237,20 @@ class DatabaseService(Service):
"uuid": query_id,
"connection_id": connection_id,
}
elif query == "ENCRYPT":
self.file_system.num_file_creations += 1
self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT
self.db_file.num_access += 1
database_folder = self._return_database_folder()
database_folder.health_status = FileSystemItemHealthStatus.CORRUPT
self.file_system.num_file_deletions += 1
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif query == "INSERT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {

View File

@@ -87,13 +87,9 @@ class NTPClient(Service):
:return: True if successful, False otherwise.
"""
if not isinstance(payload, NTPPacket):
_LOGGER.debug(f"{payload} is not a NTPPacket")
_LOGGER.debug(f"{self.name}: Failed to parse NTP update")
return False
if payload.ntp_reply.ntp_datetime:
self.sys_log.info(
f"{self.name}: \
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
)
self.time = payload.ntp_reply.ntp_datetime
return True
@@ -124,5 +120,3 @@ class NTPClient(Service):
if self.operating_state == ServiceOperatingState.RUNNING:
# request time from server
self.request_time()
else:
self.sys_log.debug(f"{self.name} ntp client not running")

View File

@@ -224,6 +224,10 @@ class Software(SimComponent):
if self.health_state_actual == SoftwareHealthState.FIXING:
self._update_fix_status()
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
class IOSoftware(Software):
"""

View File

@@ -303,6 +303,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
@@ -314,6 +316,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -325,6 +329,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -336,6 +342,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -347,6 +355,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -358,6 +368,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -505,23 +517,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

View File

@@ -5,21 +5,7 @@
# -------------- -------------- --------------
#
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true

View File

@@ -4,22 +4,7 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true

View File

@@ -30,21 +30,7 @@
# | external_computer |------| switch_3 |------| external_server |
# ----------------------- -------------- ---------------------
#
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true

View File

@@ -319,6 +319,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
@@ -330,6 +332,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -341,6 +345,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -352,6 +358,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -363,6 +371,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -374,6 +384,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -521,23 +533,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

View File

@@ -106,25 +106,6 @@ agents:
label: ICS
options: {}
# observation_space:
# type: UC2BlueObservation
# options:
# num_services_per_node: 1
# num_folders_per_node: 1
# num_files_per_folder: 1
# num_nics_per_node: 2
# nodes:
# - node_hostname: client_1
# links:
# - link_ref: client_1___switch_1
# acl:
# options:
# max_acl_rules: 10
# router_hostname: router_1
# ip_address_order:
# - node_hostname: client_1
# nic_num: 1
# ics: null
action_space:
action_list:
- type: DONOTHING
@@ -149,6 +130,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
2:
action: FIREWALL_ACL_REMOVERULE
options:
@@ -169,6 +152,8 @@ agents:
source_port_id: 2
dest_port_id: 3
protocol_id: 2
source_wildcard_id: 0
dest_wildcard_id: 0
4:
action: FIREWALL_ACL_REMOVERULE
options:
@@ -189,6 +174,8 @@ agents:
source_port_id: 4
dest_port_id: 4
protocol_id: 4
source_wildcard_id: 0
dest_wildcard_id: 0
6:
action: FIREWALL_ACL_REMOVERULE
options:
@@ -209,6 +196,8 @@ agents:
source_port_id: 4
dest_port_id: 4
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
8:
action: FIREWALL_ACL_REMOVERULE
options:
@@ -229,6 +218,8 @@ agents:
source_port_id: 5
dest_port_id: 5
protocol_id: 2
source_wildcard_id: 0
dest_wildcard_id: 0
10:
action: FIREWALL_ACL_REMOVERULE
options:
@@ -249,6 +240,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
12:
action: FIREWALL_ACL_REMOVERULE
options:
@@ -271,13 +264,10 @@ agents:
- node_name: client_1
- node_name: dmz_server
- node_name: external_computer
ip_address_order:
- node_name: client_1
nic_num: 1
- node_name: dmz_server
nic_num: 1
- node_name: external_computer
nic_num: 1
ip_list:
- 192.168.0.10
- 192.168.10.10
- 192.168.20.10
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2

View File

@@ -314,6 +314,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
@@ -325,6 +327,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -336,6 +340,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -347,6 +353,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -358,6 +366,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -369,6 +379,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -516,23 +528,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
@@ -780,6 +784,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
@@ -791,6 +797,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -802,6 +810,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -813,6 +823,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -824,6 +836,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -835,6 +849,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -981,23 +997,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

View File

@@ -1,18 +1,4 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true

View File

@@ -131,10 +131,6 @@ agents:
options:
node_hostname: client_1
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
@@ -490,6 +486,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
@@ -501,6 +499,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -512,6 +512,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -523,6 +525,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
@@ -534,6 +538,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
@@ -545,6 +551,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -703,23 +711,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:

View File

@@ -493,6 +493,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
@@ -504,6 +506,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -515,6 +519,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -526,6 +532,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
@@ -537,6 +545,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
@@ -548,6 +558,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -729,23 +741,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:

View File

@@ -327,6 +327,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
@@ -338,6 +340,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -349,6 +353,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -360,6 +366,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -371,6 +379,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -382,6 +392,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -528,23 +540,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

View File

@@ -327,6 +327,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
@@ -338,6 +340,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -349,6 +353,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
@@ -360,6 +366,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -371,6 +379,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
@@ -382,6 +392,8 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
@@ -528,23 +540,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

View File

@@ -50,8 +50,8 @@ def set_syslog_output_to_true():
"path",
Path(TEST_ASSETS_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")),
)
monkeypatch.setattr(SIM_OUTPUT, "save_pcap_logs", True)
monkeypatch.setattr(SIM_OUTPUT, "save_sys_logs", True)
monkeypatch.setattr(SIM_OUTPUT, "save_pcap_logs", False)
monkeypatch.setattr(SIM_OUTPUT, "save_sys_logs", False)
yield
@@ -529,7 +529,7 @@ def game_and_agent():
max_acl_rules=10,
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
ip_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
act_map={},
)
observation_space = ObservationManager(NestedObservation(components={}))

View File

@@ -130,6 +130,8 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
},
)
agent.store_action(action)
@@ -155,6 +157,8 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
},
)
agent.store_action(action)

View File

@@ -22,10 +22,13 @@ def test_capture_nmne(uc2_network):
web_server_nic = web_server.network_interface[1]
db_server_nic = db_server.network_interface[1]
# Set the NMNE configuration to capture DELETE queries as MNEs
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
@@ -63,6 +66,20 @@ def test_capture_nmne(uc2_network):
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 2}}}}
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 2}}}}
# Perform an "ENCRYPT" query
db_client.query("ENCRYPT")
# Check that the web server and database server interfaces register an additional MNE
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 3}}}}
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
# Perform another "SELECT" query
db_client.query("SELECT")
# Check that no additional MNEs are captured
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 3}}}}
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
def test_describe_state_nmne(uc2_network):
"""
@@ -70,7 +87,7 @@ def test_describe_state_nmne(uc2_network):
This test involves a web server querying a database server and checks if the MNEs are captured
based on predefined keywords in the network configuration. Specifically, it checks the capture
of the "DELETE" SQL command as a malicious network event. It also checks that running describe_state
of the "DELETE" / "ENCRYPT" SQL commands as a malicious network event. It also checks that running describe_state
only shows MNEs since the last time describe_state was called.
"""
web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa
@@ -82,10 +99,13 @@ def test_describe_state_nmne(uc2_network):
web_server_nic = web_server.network_interface[1]
db_server_nic = db_server.network_interface[1]
# Set the NMNE configuration to capture DELETE queries as MNEs
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # "DELETE" & "ENCRYPT" SQL commands as a keywords for MNE detection
}
# Apply the NMNE configuration settings
@@ -138,6 +158,36 @@ def test_describe_state_nmne(uc2_network):
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}}
# Perform a "ENCRYPT" query
db_client.query("ENCRYPT")
# Check that the web server's outbound interface and the database server's inbound interface register the MNE
web_server_nic_state = web_server_nic.describe_state()
db_server_nic_state = db_server_nic.describe_state()
uc2_network.apply_timestep(timestep=0)
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 3}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 3}}}}
# Perform another "SELECT" query
db_client.query("SELECT")
# Check that no additional MNEs are captured
web_server_nic_state = web_server_nic.describe_state()
db_server_nic_state = db_server_nic.describe_state()
uc2_network.apply_timestep(timestep=0)
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 3}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 3}}}}
# Perform another "ENCRYPT"
db_client.query("ENCRYPT")
# Check that the web server and database server interfaces register an additional MNE
web_server_nic_state = web_server_nic.describe_state()
db_server_nic_state = db_server_nic.describe_state()
uc2_network.apply_timestep(timestep=0)
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 4}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
def test_capture_nmne_observations(uc2_network):
"""
@@ -146,7 +196,7 @@ def test_capture_nmne_observations(uc2_network):
This test ensures the observation space, as defined by instances of NICObservation, accurately reflects the
number of MNEs detected based on network activities over multiple iterations.
The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update
The test employs a series of "DELETE" and "ENCRYPT" SQL operations, considered as MNEs, to validate the dynamic update
and accuracy of the observation space related to network interface conditions. It confirms that the
observed NIC states match expected MNE activity levels.
"""
@@ -158,10 +208,13 @@ def test_capture_nmne_observations(uc2_network):
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
# Set the NMNE configuration to capture DELETE queries as MNEs
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # Specify "DELETE" & "ENCRYPT" SQL commands as a keywords for MNE detection
}
# Apply the NMNE configuration settings
@@ -196,3 +249,28 @@ def test_capture_nmne_observations(uc2_network):
assert web_nic_obs["outbound"] == expected_nmne
assert db_nic_obs["inbound"] == expected_nmne
uc2_network.apply_timestep(timestep=0)
for i in range(0, 20):
# Perform a "ENCRYPT" query each iteration
for j in range(i):
db_client.query("ENCRYPT")
# Observe the current state of NMNEs from the NICs of both the database and web servers
state = sim.describe_state()
db_nic_obs = db_server_nic_obs.observe(state)["NMNE"]
web_nic_obs = web_server_nic_obs.observe(state)["NMNE"]
# Define expected NMNE values based on the iteration count
if i > 10:
expected_nmne = 3 # High level of detected MNEs after 10 iterations
elif i > 5:
expected_nmne = 2 # Moderate level after more than 5 iterations
elif i > 0:
expected_nmne = 1 # Low level detected after just starting
else:
expected_nmne = 0 # No MNEs detected
# Assert that the observed NMNEs match the expected values for both NICs
assert web_nic_obs["outbound"] == expected_nmne
assert db_nic_obs["inbound"] == expected_nmne
uc2_network.apply_timestep(timestep=0)

View File

@@ -152,6 +152,22 @@ def test_with_routes_can_ping(multi_hop_network):
assert pc_a.ping(pc_b.network_interface[1].ip_address)
def test_with_default_routes_can_ping(multi_hop_network):
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
pc_b = multi_hop_network.get_node_by_hostname("pc_b")
router_1: Router = multi_hop_network.get_node_by_hostname("router_1") # noqa
router_2: Router = multi_hop_network.get_node_by_hostname("router_2") # noqa
# Configure Route from Router 1 to PC B subnet
router_1.route_table.set_default_route_next_hop_ip_address("192.168.1.2")
# Configure Route from Router 2 to PC A subnet
router_2.route_table.set_default_route_next_hop_ip_address("192.168.1.1")
assert pc_a.ping(pc_b.network_interface[1].ip_address)
def test_ping_router_port_multi_hop(multi_hop_network):
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
router_2 = multi_hop_network.get_node_by_hostname("router_2")

View File

@@ -73,6 +73,7 @@ def dos_bot_db_server_green_client(example_network) -> Network:
return network
@pytest.mark.xfail(reason="Tests fail due to recent changes in how DB connections are handled for example layout.")
def test_repeating_dos_attack(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
@@ -104,6 +105,7 @@ def test_repeating_dos_attack(dos_bot_and_db_server):
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
@pytest.mark.xfail(reason="Tests fail due to recent changes in how DB connections are handled for example layout.")
def test_non_repeating_dos_attack(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
@@ -135,6 +137,7 @@ def test_non_repeating_dos_attack(dos_bot_and_db_server):
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
@pytest.mark.xfail(reason="Tests fail due to recent changes in how DB connections are handled for example layout.")
def test_dos_bot_database_service_connection(dos_bot_and_db_server):
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
@@ -147,6 +150,7 @@ def test_dos_bot_database_service_connection(dos_bot_and_db_server):
assert len(dos_bot.connections) == db_server_service.max_sessions
@pytest.mark.xfail(reason="Tests fail due to recent changes in how DB connections are handled for example layout.")
def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client):
network: Network = dos_bot_db_server_green_client

View File

@@ -0,0 +1,163 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.ransomware_script import (
RansomwareAttackStage,
RansomwareScript,
)
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.software import SoftwareHealthState
@pytest.fixture(scope="function")
def ransomware_script_and_db_server(client_server) -> Tuple[RansomwareScript, Computer, DatabaseService, Server]:
computer, server = client_server
# install db client on computer
computer.software_manager.install(DatabaseClient)
db_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient")
db_client.run()
# Install DoSBot on computer
computer.software_manager.install(RansomwareScript)
ransomware_script_application: RansomwareScript = computer.software_manager.software.get("RansomwareScript")
ransomware_script_application.configure(
server_ip_address=IPv4Address(server.network_interface[1].ip_address), payload="ENCRYPT"
)
# Install DB Server service on server
server.software_manager.install(DatabaseService)
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_server_service.start()
return ransomware_script_application, computer, db_server_service, server
@pytest.fixture(scope="function")
def ransomware_script_db_server_green_client(example_network) -> Network:
network: Network = example_network
router_1: Router = example_network.get_node_by_hostname("router_1")
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
client_1: Computer = network.get_node_by_hostname("client_1")
client_2: Computer = network.get_node_by_hostname("client_2")
server: Server = network.get_node_by_hostname("server_1")
# install db client on client 1
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
db_client.run()
# install Ransomware Script bot on client 1
client_1.software_manager.install(RansomwareScript)
ransomware_script_application: RansomwareScript = client_1.software_manager.software.get("RansomwareScript")
ransomware_script_application.configure(
server_ip_address=IPv4Address(server.network_interface[1].ip_address), payload="ENCRYPT"
)
# install db server service on server
server.software_manager.install(DatabaseService)
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
db_server_service.start()
# Install DB client (green) on client 2
client_2.software_manager.install(DatabaseClient)
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address(server.network_interface[1].ip_address))
database_client.run()
return network
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
"""Test a repeating data manipulation attack."""
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
assert computer.file_system.num_file_creations == 0
RansomwareScript.target_scan_p_of_success = 1
RansomwareScript.c2_beacon_p_of_success = 1
RansomwareScript.ransomware_encrypt_p_of_success = 1
RansomwareScript.repeat = True
RansomwareScript.attack()
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
assert computer.file_system.num_file_creations == 1
computer.apply_timestep(timestep=1)
server.apply_timestep(timestep=1)
assert RansomwareScript.attack_stage == RansomwareAttackStage.NOT_STARTED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
def test_repeating_ransomware_script_attack(ransomware_script_and_db_server):
"""Test a repeating ransowmare script attack."""
RansomwareScript, computer, db_server_service, server = ransomware_script_and_db_server
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
RansomwareScript.target_scan_p_of_success = 1
RansomwareScript.c2_beacon_p_of_success = 1
RansomwareScript.ransomware_encrypt_p_of_success = 1
RansomwareScript.repeat = False
RansomwareScript.attack()
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert computer.file_system.num_file_creations == 1
computer.apply_timestep(timestep=1)
computer.pre_timestep(timestep=1)
server.apply_timestep(timestep=1)
server.pre_timestep(timestep=1)
assert RansomwareScript.attack_stage == RansomwareAttackStage.SUCCEEDED
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert computer.file_system.num_file_creations == 0
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
"""Test to see show that the database service still operate"""
network: Network = ransomware_script_db_server_green_client
client_1: Computer = network.get_node_by_hostname("client_1")
ransomware_script_application: RansomwareScript = client_1.software_manager.software.get("RansomwareScript")
client_2: Computer = network.get_node_by_hostname("client_2")
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
server: Server = network.get_node_by_hostname("server_1")
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
assert green_db_client.query("SELECT")
assert green_db_client.last_query_response.get("status_code") == 200
ransomware_script_application.target_scan_p_of_success = 1
ransomware_script_application.ransomware_encrypt_p_of_success = 1
ransomware_script_application.c2_beacon_p_of_success = 1
ransomware_script_application.repeat = False
ransomware_script_application.attack()
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert green_db_client.query("SELECT") is True
assert green_db_client.last_query_response.get("status_code") == 200

View File

@@ -21,6 +21,7 @@ def test_create_folder_and_file(file_system):
assert file_system.get_folder("test_folder").get_file("test_file.txt")
file_system.apply_timestep(0)
file_system.pre_timestep(0)
# num file creations should reset
assert file_system.num_file_creations == 0
@@ -38,6 +39,7 @@ def test_create_file_no_folder(file_system):
assert file_system.get_folder("root").get_file("test_file.txt").size == 10
file_system.apply_timestep(0)
file_system.pre_timestep(0)
# num file creations should reset
assert file_system.num_file_creations == 0
@@ -59,6 +61,7 @@ def test_delete_file(file_system):
assert len(file_system.get_folder("root").deleted_files) == 1
file_system.apply_timestep(0)
file_system.pre_timestep(0)
# num file deletions should reset
assert file_system.num_file_deletions == 0
@@ -174,6 +177,7 @@ def test_move_file(file_system):
assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid
file_system.apply_timestep(0)
file_system.pre_timestep(0)
# num file creations and deletions should reset
assert file_system.num_file_creations == 0
@@ -203,6 +207,7 @@ def test_copy_file(file_system):
assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid
file_system.apply_timestep(0)
file_system.pre_timestep(0)
# num file creations should reset
assert file_system.num_file_creations == 0