Merge branch 'dev' into bugfix/2455-notebook_updates
This commit is contained in:
@@ -1 +1 @@
|
||||
3.0.0b7
|
||||
3.0.0b8
|
||||
|
||||
@@ -1,15 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
@@ -490,6 +478,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -501,6 +491,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -512,6 +504,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -523,6 +517,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
50: # old action num: 26
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -534,6 +530,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
51: # old action num: 27
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -545,6 +543,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
52: # old action num: 28
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
@@ -703,23 +703,15 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
@@ -730,10 +722,12 @@ agents:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_1_green_user
|
||||
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
|
||||
@@ -492,6 +492,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -503,6 +505,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -514,6 +518,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -525,6 +531,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
50: # old action num: 26
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -536,6 +544,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
51: # old action num: 27
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
@@ -547,6 +557,8 @@ agents:
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
52: # old action num: 28
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
@@ -704,23 +716,15 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
@@ -1284,23 +1288,15 @@ agents:
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -487,7 +487,9 @@ class RouterACLAddRuleAction(AbstractAction):
|
||||
position: int,
|
||||
permission: int,
|
||||
source_ip_id: int,
|
||||
source_wildcard_id: int,
|
||||
dest_ip_id: int,
|
||||
dest_wildcard_id: int,
|
||||
source_port_id: int,
|
||||
dest_port_id: int,
|
||||
protocol_id: int,
|
||||
@@ -519,7 +521,7 @@ class RouterACLAddRuleAction(AbstractAction):
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
|
||||
if source_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif source_port_id == 1:
|
||||
@@ -528,13 +530,14 @@ class RouterACLAddRuleAction(AbstractAction):
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_ip_id == 0:
|
||||
if dest_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
|
||||
|
||||
if dest_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
@@ -553,8 +556,10 @@ class RouterACLAddRuleAction(AbstractAction):
|
||||
permission_str,
|
||||
protocol,
|
||||
str(src_ip),
|
||||
src_wildcard,
|
||||
src_port,
|
||||
str(dst_ip),
|
||||
dst_wildcard,
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
@@ -624,7 +629,9 @@ class FirewallACLAddRuleAction(AbstractAction):
|
||||
position: int,
|
||||
permission: int,
|
||||
source_ip_id: int,
|
||||
source_wildcard_id: int,
|
||||
dest_ip_id: int,
|
||||
dest_wildcard_id: int,
|
||||
source_port_id: int,
|
||||
dest_port_id: int,
|
||||
protocol_id: int,
|
||||
@@ -665,7 +672,7 @@ class FirewallACLAddRuleAction(AbstractAction):
|
||||
src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if source_ip_id == 0:
|
||||
if dest_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
@@ -680,6 +687,8 @@ class FirewallACLAddRuleAction(AbstractAction):
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
|
||||
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
|
||||
|
||||
return [
|
||||
"network",
|
||||
@@ -692,8 +701,10 @@ class FirewallACLAddRuleAction(AbstractAction):
|
||||
permission_str,
|
||||
protocol,
|
||||
str(src_ip),
|
||||
src_wildcard,
|
||||
src_port,
|
||||
str(dst_ip),
|
||||
dst_wildcard,
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
@@ -871,7 +882,8 @@ class ActionManager:
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
|
||||
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
ip_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
wildcard_list: List[str] = [], # to allow mapping from wildcard index to
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
@@ -897,8 +909,8 @@ class ActionManager:
|
||||
:type protocols: List[str]
|
||||
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
|
||||
:type ports: List[str]
|
||||
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_address_list: Optional[List[str]]
|
||||
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_list: Optional[List[str]]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
@@ -959,8 +971,10 @@ class ActionManager:
|
||||
self.protocols: List[str] = protocols
|
||||
self.ports: List[str] = ports
|
||||
|
||||
self.ip_address_list: List[str] = ip_address_list
|
||||
|
||||
self.ip_address_list: List[str] = ip_list
|
||||
self.wildcard_list: List[str] = wildcard_list
|
||||
if self.wildcard_list == []:
|
||||
self.wildcard_list = ["NONE"]
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
"num_nodes": len(self.node_names),
|
||||
@@ -1195,6 +1209,24 @@ class ActionManager:
|
||||
raise RuntimeError(msg)
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
|
||||
"""
|
||||
Get the IP wildcard corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP wildcard to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The wildcard address.
|
||||
:rtype: str
|
||||
"""
|
||||
if wildcard_idx >= len(self.wildcard_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
|
||||
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.wildcard_list[wildcard_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
"""
|
||||
Get the port corresponding to the given index.
|
||||
@@ -1253,37 +1285,14 @@ class ActionManager:
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
|
||||
# the nodes in the simulation.
|
||||
# TODO: refactor. Options:
|
||||
# 1: This should be pulled out into it's own function for clarity
|
||||
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
|
||||
# go through the nodes here.
|
||||
ip_address_order = cfg["options"].pop("ip_address_order", {})
|
||||
ip_address_list = []
|
||||
for entry in ip_address_order:
|
||||
node_name = entry["node_name"]
|
||||
nic_num = entry["nic_num"]
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_name)
|
||||
ip_address = node_obj.network_interface[nic_num].ip_address
|
||||
ip_address_list.append(ip_address)
|
||||
|
||||
if not ip_address_list:
|
||||
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
|
||||
for node_name in node_names:
|
||||
node_obj = game.simulation.network.get_node_by_hostname(node_name)
|
||||
if node_obj is None:
|
||||
continue
|
||||
network_interfaces = node_obj.network_interfaces
|
||||
for nic_uuid, nic_obj in network_interfaces.items():
|
||||
ip_address_list.append(nic_obj.ip_address)
|
||||
if "ip_list" not in cfg["options"]:
|
||||
cfg["options"]["ip_list"] = []
|
||||
|
||||
obj = cls(
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
ip_address_list=ip_address_list,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +43,10 @@ class LinkObservation(AbstractObservation, identifier="LINK"):
|
||||
"""
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
self.where[-1] = "<->".join(self.where[-1].split("<->")[::-1]) # try swapping endpoint A and B
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
|
||||
@@ -189,7 +189,6 @@ class ObservationManager:
|
||||
"""
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
print(config)
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
|
||||
|
||||
@@ -26,7 +26,7 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
WhereType = Iterable[str | int] | None
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
@@ -293,6 +294,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
return 0.0
|
||||
last_connection_successful = db_state["last_connection_successful"]
|
||||
if last_connection_successful is False:
|
||||
return -1.0
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from typing import Dict, Tuple
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
@@ -19,3 +24,60 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class PeriodicAgent(AbstractScriptedAgent):
|
||||
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
start_step: int = 20
|
||||
"The timestep at which an agent begins performing it's actions."
|
||||
start_variance: int = 5
|
||||
"Deviation around the start step."
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions."
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to."
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: ActionManager,
|
||||
observation_space: ObservationManager,
|
||||
reward_function: RewardFunction,
|
||||
settings: Optional[Settings] = None,
|
||||
) -> None:
|
||||
"""Initialise PeriodicAgent."""
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.settings = settings or PeriodicAgent.Settings()
|
||||
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
|
||||
self.num_executions = 0
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep when the next execute action should be taken.
|
||||
:type timestep: int
|
||||
:param variance: Uniform random variance applied to the timestep
|
||||
:type variance: int
|
||||
"""
|
||||
random_increment = random.randint(-variance, variance)
|
||||
self.next_execution_timestep = timestep + random_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
|
||||
self.num_executions += 1
|
||||
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
|
||||
return "DONOTHING", {}
|
||||
|
||||
78
src/primaite/game/agent/scripted_agents/tap001.py
Normal file
78
src/primaite/game/agent/scripted_agents/tap001.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class TAP001(AbstractScriptedAgent):
|
||||
"""
|
||||
TAP001 | Mobile Malware -- Ransomware Variant.
|
||||
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
installed: bool = False
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
|
||||
This application acts a wrapper around the kill-chain, similar to green-analyst and
|
||||
the previous UC2 data manipulation bot.
|
||||
|
||||
:param obs: Current observation for this agent.
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
if not self.installed:
|
||||
self.installed = True
|
||||
return "NODE_APPLICATION_INSTALL", {
|
||||
"node_id": self.starting_node_idx,
|
||||
"application_name": "RansomwareScript",
|
||||
"ip_address": self.ip_address,
|
||||
}
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
for n, act in self.action_manager.action_map.items():
|
||||
if not act[0] == "NODE_APPLICATION_INSTALL":
|
||||
continue
|
||||
if act[1]["node_id"] == self.starting_node_idx:
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
@@ -11,7 +11,10 @@ from primaite.game.agent.observations.observation_manager import ObservationMana
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator.network.airspace import AIR_SPACE
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -26,6 +29,7 @@ from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
@@ -43,6 +47,7 @@ APPLICATION_TYPES_MAPPING = {
|
||||
"DatabaseClient": DatabaseClient,
|
||||
"DataManipulationBot": DataManipulationBot,
|
||||
"DoSBot": DoSBot,
|
||||
"RansomwareScript": RansomwareScript,
|
||||
}
|
||||
"""List of available applications that can be installed on nodes in the PrimAITE Simulation."""
|
||||
|
||||
@@ -128,6 +133,8 @@ class PrimaiteGame:
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
self.pre_timestep()
|
||||
|
||||
if self.step_counter == 0:
|
||||
state = self.get_sim_state()
|
||||
for agent in self.agents.values():
|
||||
@@ -172,6 +179,10 @@ class PrimaiteGame:
|
||||
response=response,
|
||||
)
|
||||
|
||||
def pre_timestep(self) -> None:
|
||||
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
|
||||
self.simulation.pre_timestep(self.step_counter)
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
self.step_counter += 1
|
||||
@@ -211,6 +222,7 @@ class PrimaiteGame:
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
AIR_SPACE.clear()
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
|
||||
@@ -268,6 +280,9 @@ class PrimaiteGame:
|
||||
hostname=node_cfg["hostname"],
|
||||
ip_address=node_cfg["ip_address"],
|
||||
subnet_mask=node_cfg["subnet_mask"],
|
||||
operating_state=NodeOperatingState.ON
|
||||
if not (p := node_cfg.get("operating_state"))
|
||||
else NodeOperatingState[p.upper()],
|
||||
)
|
||||
else:
|
||||
msg = f"invalid node type {n_type} in config"
|
||||
@@ -339,6 +354,19 @@ class PrimaiteGame:
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
)
|
||||
elif application_type == "RansomwareScript":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
|
||||
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
|
||||
ransomware_encrypt_p_of_success=float(
|
||||
opt.get("ransomware_encrypt_p_of_success", "0.1")
|
||||
),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
@@ -383,6 +411,7 @@ class PrimaiteGame:
|
||||
for link_cfg in links_cfg:
|
||||
node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"])
|
||||
node_b = net.get_node_by_hostname(link_cfg["endpoint_b_hostname"])
|
||||
|
||||
if isinstance(node_a, Switch):
|
||||
endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
|
||||
else:
|
||||
@@ -423,6 +452,16 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
elif agent_type == "PeriodicAgent":
|
||||
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
|
||||
new_agent = PeriodicAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
elif agent_type == "ProxyAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = ProxyAgent(
|
||||
@@ -443,6 +482,15 @@ class PrimaiteGame:
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
elif agent_type == "TAP001":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = TAP001(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
else:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
|
||||
@@ -26,6 +26,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def __init__(self, game_config: Dict):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
@@ -49,6 +52,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
step = self.game.step_counter
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
@@ -224,6 +228,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
|
||||
@@ -29,10 +29,12 @@ class PrimaiteIO:
|
||||
"""Whether to save a log of all agents' actions every step."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
save_pcap_logs: bool = False
|
||||
save_pcap_logs: bool = True
|
||||
"""Whether to save PCAP logs."""
|
||||
save_sys_logs: bool = False
|
||||
save_sys_logs: bool = True
|
||||
"""Whether to save system logs."""
|
||||
write_sys_log_to_terminal: bool = False
|
||||
"""Whether to write the sys log to the terminal."""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None) -> None:
|
||||
"""
|
||||
@@ -47,6 +49,7 @@ class PrimaiteIO:
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
@@ -93,4 +96,5 @@ class PrimaiteIO:
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
"""Create an instance of PrimaiteIO based on a configuration dict."""
|
||||
new = cls(settings=cls.Settings(**config))
|
||||
|
||||
return new
|
||||
|
||||
@@ -14,6 +14,7 @@ class _SimOutput:
|
||||
)
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
self.write_sys_log_to_terminal: bool = False
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
|
||||
@@ -258,8 +258,7 @@
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "03b2013a-b7d1-47ee-b08c-8dab83833720",
|
||||
"id": "0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PrimAITE Router Simulation Demo\n",
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "c8bb5698-e746-4e90-9c2f-efe962acdfa0",
|
||||
"id": "1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" +------------+\n",
|
||||
@@ -48,7 +48,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "415d487c-6457-497d-85d6-99439b3541e7",
|
||||
"id": "2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## The Network\n",
|
||||
@@ -60,7 +60,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de57ac8c-5b28-4847-a759-2ceaf5593329",
|
||||
"id": "3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -72,7 +72,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a1e2e4df-67c0-4584-ab27-47e2c7c7fcd2",
|
||||
"id": "4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -83,7 +83,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fb052c56-e9ca-4093-9115-d0c440b5ff53",
|
||||
"id": "5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Most of the Network components have a `.show()` function that prints a table of information about that object. We can view the Nodes and Links on the Network by calling `network.show()`."
|
||||
@@ -92,7 +92,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc199741-ef2e-47f5-b2f0-e20049ccf40f",
|
||||
"id": "6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -103,7 +103,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "76d2b7e9-280b-4741-a8b3-a84bed219fac",
|
||||
"id": "7",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -115,7 +115,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84113002-843e-4cab-b899-667b50f25f6b",
|
||||
"id": "8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Router Nodes\n",
|
||||
@@ -125,7 +125,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf63a178-eee5-4669-bf64-13aea7ecf6cb",
|
||||
"id": "9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.show()` displays the Ethernet interfaces on the Router. If you need a table in markdown format, pass `markdown=True`."
|
||||
@@ -134,7 +134,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e76d1854-961e-438c-b40f-77fd9c3abe38",
|
||||
"id": "10",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -145,7 +145,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e000540c-687c-4254-870c-1d814603bdbf",
|
||||
"id": "11",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.arp.show()` displays the Router ARP Cache."
|
||||
@@ -154,7 +154,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "92de8b42-92d7-4934-9c12-50bf724c9eb2",
|
||||
"id": "12",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -165,7 +165,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a9ff7ee8-9482-44de-9039-b684866bdc82",
|
||||
"id": "13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.acl.show()` displays the Access Control List."
|
||||
@@ -174,7 +174,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5922282a-d22b-4e55-9176-f3f3654c849f",
|
||||
"id": "14",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -185,7 +185,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "71c87884-f793-4c9f-b004-5b0df86cf585",
|
||||
"id": "15",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.router_table.show()` displays the static routes the Router provides."
|
||||
@@ -194,7 +194,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "327203be-f475-4727-82a1-e992d3b70ed8",
|
||||
"id": "16",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -205,7 +205,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eef561a8-3d39-4c8b-bbc8-e8b10b8ed25f",
|
||||
"id": "17",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `router.sys_log.show()` displays the Router system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
@@ -214,7 +214,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3d0aa004-b10c-445f-aaab-340e0e716c74",
|
||||
"id": "18",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -225,7 +225,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25630c90-c54e-4b5d-8bf4-ad1b0722e126",
|
||||
"id": "19",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Switch Nodes\n",
|
||||
@@ -235,7 +235,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4879394d-2981-40de-a229-e19b09a34e6e",
|
||||
"id": "20",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.show()` displays the Switch orts on the Switch."
|
||||
@@ -244,7 +244,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7fd439b-5442-4e9d-9e7d-86dacb77f458",
|
||||
"id": "21",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -255,7 +255,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "beb8dbd6-7250-4ac9-9fa2-d2a9c0e5fd19",
|
||||
"id": "22",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -266,7 +266,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d06e1310-4a77-4315-a59f-cb1b49ca2352",
|
||||
"id": "23",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -277,7 +277,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fda75ac3-8123-4234-8f36-86547891d8df",
|
||||
"id": "24",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
@@ -286,7 +286,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0d984b7-a7c1-4bbd-aa5a-9d3caecb08dc",
|
||||
"id": "25",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -297,7 +297,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f1d99ad-db4f-4baf-8a35-e1d95f269586",
|
||||
"id": "26",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Computer/Server Nodes\n",
|
||||
@@ -307,7 +307,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c9e2251a-1b47-46e5-840f-7fec3e39c5aa",
|
||||
"id": "27",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -318,7 +318,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "656c37f6-b145-42af-9714-8d2886d0eff8",
|
||||
"id": "28",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -329,7 +329,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f1097a49-a3da-4d79-a06d-ae8af452918f",
|
||||
"id": "29",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
|
||||
@@ -338,7 +338,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "66b267d6-2308-486a-b9aa-cb8d3bcf0753",
|
||||
"id": "30",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -349,7 +349,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0d1fcad8-5b1a-4d8b-a49f-aa54a95fcaf0",
|
||||
"id": "31",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
@@ -358,7 +358,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b5debe8-ef1b-445d-8fa9-6a45568f21f3",
|
||||
"id": "32",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -369,7 +369,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fcfa1773-798c-4ada-9318-c3ad928217da",
|
||||
"id": "33",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Network Comms Check\n",
|
||||
@@ -380,7 +380,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "495b7de4-b6ce-41a6-9114-f74752ab4491",
|
||||
"id": "34",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -391,7 +391,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3e13922a-217f-4f4e-99b6-57a07613cade",
|
||||
"id": "35",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll first ping client_1's default gateway."
|
||||
@@ -400,7 +400,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a38abb71-994e-49e8-8f51-e9a550e95b99",
|
||||
"id": "36",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -412,7 +412,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8388e1e9-30e3-4534-8e5a-c6e9144149d2",
|
||||
"id": "37",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -423,7 +423,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "02c76d5c-d954-49db-912d-cb9c52f46375",
|
||||
"id": "38",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
|
||||
@@ -432,7 +432,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff8e976a-c16b-470c-8923-325713a30d6c",
|
||||
"id": "39",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -443,7 +443,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "80280404-a5ab-452f-8a02-771a0d7496b1",
|
||||
"id": "40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And finally, we'll ping the web server."
|
||||
@@ -452,7 +452,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4163f8d-6a72-410c-9f5c-4f881b7de45e",
|
||||
"id": "41",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -463,7 +463,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1194c045-ba77-4427-be30-ed7b5b224850",
|
||||
"id": "42",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
|
||||
@@ -472,7 +472,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e79a523a-5780-45b6-8798-c434e0e522bd",
|
||||
"id": "43",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -483,7 +483,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5928f6dd-1006-45e3-99f3-8f311a875faa",
|
||||
"id": "44",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced Network Usage\n",
|
||||
@@ -493,7 +493,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5e023ef3-7d18-4006-96ee-042a06a481fc",
|
||||
"id": "45",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
|
||||
@@ -502,7 +502,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "603cf913-e261-49da-a7dd-85e1bb6dec56",
|
||||
"id": "46",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -513,7 +513,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5cf962a4-20e6-44ae-9748-7fc5267ae111",
|
||||
"id": "47",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
|
||||
@@ -522,7 +522,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e047de00-3de4-4823-b26a-2c8d64c7a663",
|
||||
"id": "48",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -533,7 +533,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bdc4741d-6e3e-4aec-a69c-c2e9653bd02c",
|
||||
"id": "49",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we'll add an ACL to block ICMP from 192.168.10.22"
|
||||
@@ -542,7 +542,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6db355ae-b99a-441b-a2c4-4ffe78f46bff",
|
||||
"id": "50",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -562,7 +562,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a345e000-8842-4827-af96-adc0fbe390fb",
|
||||
"id": "51",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -573,7 +573,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3a5bfd9f-04cb-493e-a86c-cd268563a262",
|
||||
"id": "52",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we attempt (and fail) to ping the web server"
|
||||
@@ -582,7 +582,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a4f4ff31-590f-40fb-b13d-efaa8c2720b6",
|
||||
"id": "53",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -593,7 +593,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "83e56497-097b-45cb-964e-b15c72547b38",
|
||||
"id": "54",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check that the ping was actually sent by client_2 by viewing the sys log"
|
||||
@@ -602,7 +602,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f62b8a4e-fd3b-4059-b108-3d4a0b18f2a0",
|
||||
"id": "55",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -613,7 +613,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c7040311-a879-4620-86a0-55d0774156e5",
|
||||
"id": "56",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check the router sys log to see why the traffic was blocked"
|
||||
@@ -622,7 +622,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e53d776-99da-4d2c-a2a7-bd7ce27bff4c",
|
||||
"id": "57",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -633,7 +633,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aba0bc7d-da57-477b-b34a-3688b5aab2c6",
|
||||
"id": "58",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now a final check to ensure that client_1 can still ping the web_server."
|
||||
@@ -642,7 +642,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d542734b-7582-4af7-8254-bda3de50d091",
|
||||
"id": "59",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -654,7 +654,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d78e9fe3-02c6-4792-944f-5622e26e0412",
|
||||
"id": "60",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
|
||||
@@ -226,6 +226,15 @@ class SimComponent(BaseModel):
|
||||
return
|
||||
return self._request_manager(request, context)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply any logic that needs to happen at the beginning of the timestep to ensure correct observations/rewards.
|
||||
|
||||
:param timestep: what's the current time
|
||||
:type timestep: int
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
|
||||
@@ -103,6 +103,10 @@ class File(FileSystemItemABC):
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
# reset the number of accesses to 0
|
||||
self.num_access = 0
|
||||
|
||||
|
||||
@@ -427,15 +427,21 @@ class FileSystem(SimComponent):
|
||||
"""Apply time step to FileSystem and its child folders and files."""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
# apply timestep to folders
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
# reset number of file creations
|
||||
self.num_file_creations = 0
|
||||
|
||||
# reset number of file deletions
|
||||
self.num_file_deletions = 0
|
||||
|
||||
# apply timestep to folders
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].apply_timestep(timestep=timestep)
|
||||
for folder in self.folders.values():
|
||||
folder.pre_timestep(timestep)
|
||||
|
||||
###############################################################
|
||||
# Agent actions
|
||||
|
||||
@@ -128,6 +128,13 @@ class Folder(FileSystemItemABC):
|
||||
for file_id in self.files:
|
||||
self.files[file_id].apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
for file in self.files.values():
|
||||
file.pre_timestep(timestep)
|
||||
|
||||
def _scan_timestep(self) -> None:
|
||||
"""Apply the scan action timestep."""
|
||||
if self.scan_countdown >= 0:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -86,6 +87,16 @@ class Network(SimComponent):
|
||||
for link_id in self.links:
|
||||
self.links[link_id].apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
for node in self.nodes.values():
|
||||
node.pre_timestep(timestep)
|
||||
|
||||
for link in self.links.values():
|
||||
link.pre_timestep(timestep)
|
||||
|
||||
@property
|
||||
def router_nodes(self) -> List[Node]:
|
||||
"""The Routers in the Network."""
|
||||
@@ -163,10 +174,11 @@ class Network(SimComponent):
|
||||
for node in nodes:
|
||||
for i, port in node.network_interface.items():
|
||||
if hasattr(port, "ip_address"):
|
||||
port_str = port.port_name if port.port_name else port.port_num
|
||||
table.add_row(
|
||||
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
|
||||
)
|
||||
if port.ip_address != IPv4Address("127.0.0.1"):
|
||||
port_str = port.port_name if port.port_name else port.port_num
|
||||
table.add_row(
|
||||
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
|
||||
)
|
||||
print(table)
|
||||
|
||||
if links:
|
||||
|
||||
@@ -9,7 +9,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int:
|
||||
"""
|
||||
Calculate the minimum number of network switches required to connect a given number of nodes.
|
||||
|
||||
@@ -18,7 +18,7 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
to accommodate all nodes under this constraint.
|
||||
|
||||
:param num_nodes: The total number of nodes that need to be connected in the network.
|
||||
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
|
||||
:param max_network_interface: The maximum number of ports available on each switch. Defaults to 24.
|
||||
|
||||
:return: The minimum number of switches required to connect all PCs.
|
||||
|
||||
@@ -33,11 +33,11 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
|
||||
3
|
||||
"""
|
||||
# Reduce the effective number of switch ports by 1 to leave space for the router
|
||||
effective_switch_ports = max_switch_ports - 1
|
||||
effective_network_interface = max_network_interface - 1
|
||||
|
||||
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
|
||||
full_switches = num_nodes // effective_switch_ports
|
||||
extra_pcs = num_nodes % effective_switch_ports
|
||||
full_switches = num_nodes // effective_network_interface
|
||||
extra_pcs = num_nodes % effective_network_interface
|
||||
|
||||
# Return the total number of switches required
|
||||
return full_switches + (1 if extra_pcs > 0 else 0)
|
||||
@@ -77,7 +77,7 @@ def create_office_lan(
|
||||
|
||||
# Calculate the required number of switches
|
||||
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
|
||||
effective_switch_ports = 23 # One port less for router connection
|
||||
effective_network_interface = 23 # One port less for router connection
|
||||
if pcs_ip_block_start <= num_of_switches:
|
||||
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
|
||||
|
||||
@@ -116,7 +116,7 @@ def create_office_lan(
|
||||
# Add PCs to the LAN and connect them to switches
|
||||
for i in range(1, num_pcs + 1):
|
||||
# Add a new edge switch if the current one is full
|
||||
if switch_port == effective_switch_ports:
|
||||
if switch_port == effective_network_interface:
|
||||
switch_n += 1
|
||||
switch_port = 0
|
||||
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)
|
||||
|
||||
@@ -264,6 +264,9 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"""
|
||||
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.uuid)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
@@ -661,6 +664,10 @@ class Link(SimComponent):
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""Apply a timestep to the simulation."""
|
||||
super().apply_timestep(timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
self.current_load = 0.0
|
||||
|
||||
|
||||
@@ -895,6 +902,10 @@ class Node(SimComponent):
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
|
||||
return WebBrowser
|
||||
elif application_class_str == "RansomwareScript":
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
|
||||
return RansomwareScript
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -965,12 +976,15 @@ class Node(SimComponent):
|
||||
table.align = "l"
|
||||
table.title = f"{self.hostname} Network Interface Cards"
|
||||
for port, network_interface in self.network_interface.items():
|
||||
ip_address = ""
|
||||
if hasattr(network_interface, "ip_address"):
|
||||
ip_address = f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}"
|
||||
table.add_row(
|
||||
[
|
||||
port,
|
||||
network_interface.__class__.__name__,
|
||||
network_interface.mac_address,
|
||||
f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}",
|
||||
ip_address,
|
||||
network_interface.speed,
|
||||
"Enabled" if network_interface.enabled else "Disabled",
|
||||
]
|
||||
@@ -1071,6 +1085,23 @@ class Node(SimComponent):
|
||||
|
||||
self.file_system.apply_timestep(timestep=timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.pre_timestep(timestep=timestep)
|
||||
|
||||
for process_id in self.processes:
|
||||
self.processes[process_id].pre_timestep(timestep=timestep)
|
||||
|
||||
for service_id in self.services:
|
||||
self.services[service_id].pre_timestep(timestep=timestep)
|
||||
|
||||
for application_id in self.applications:
|
||||
self.applications[application_id].pre_timestep(timestep=timestep)
|
||||
|
||||
self.file_system.pre_timestep(timestep=timestep)
|
||||
|
||||
def scan(self) -> bool:
|
||||
"""
|
||||
Scan the node and all the items within it.
|
||||
@@ -1341,6 +1372,8 @@ class Node(SimComponent):
|
||||
application_instance.configure(target_ip_address=IPv4Address(ip_address))
|
||||
elif application_instance.name == "DataManipulationBot":
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
elif application_instance.name == "RansomwareScript":
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
@@ -599,7 +599,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -612,7 +614,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -625,7 +629,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -638,7 +644,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -651,7 +659,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
@@ -664,7 +674,9 @@ class Firewall(Router):
|
||||
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
|
||||
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
|
||||
src_ip_address=r_cfg.get("src_ip"),
|
||||
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
|
||||
dst_ip_address=r_cfg.get("dst_ip"),
|
||||
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
|
||||
position=r_num,
|
||||
)
|
||||
|
||||
|
||||
@@ -322,10 +322,12 @@ class AccessControlList(SimComponent):
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]),
|
||||
src_port=None if request[4] == "ALL" else Port[request[4]],
|
||||
dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]),
|
||||
dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]),
|
||||
dst_port=None if request[7] == "ALL" else Port[request[7]],
|
||||
position=int(request[8]),
|
||||
)
|
||||
)
|
||||
),
|
||||
@@ -772,6 +774,13 @@ class RouterARP(ARP):
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=is_default_route_attempt,
|
||||
)
|
||||
elif route and route == self.router.route_table.default_route:
|
||||
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
|
||||
return self._get_arp_cache_mac_address(
|
||||
ip_address=self.router.route_table.default_route.next_hop_ip_address,
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=True,
|
||||
)
|
||||
else:
|
||||
if self.router.route_table.default_route:
|
||||
if not is_default_route_attempt:
|
||||
@@ -822,6 +831,12 @@ class RouterARP(ARP):
|
||||
return network_interface
|
||||
|
||||
if not is_reattempt:
|
||||
if self.router.ip_is_in_router_interface_subnet(ip_address):
|
||||
self.send_arp_request(ip_address)
|
||||
return self._get_arp_cache_network_interface(
|
||||
ip_address=ip_address, is_reattempt=True, is_default_route_attempt=is_default_route_attempt
|
||||
)
|
||||
|
||||
route = self.router.route_table.find_best_route(ip_address)
|
||||
if route and route != self.router.route_table.default_route:
|
||||
self.send_arp_request(route.next_hop_ip_address)
|
||||
@@ -830,6 +845,13 @@ class RouterARP(ARP):
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=is_default_route_attempt,
|
||||
)
|
||||
elif route and route == self.router.route_table.default_route:
|
||||
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
|
||||
return self._get_arp_cache_network_interface(
|
||||
ip_address=self.router.route_table.default_route.next_hop_ip_address,
|
||||
is_reattempt=True,
|
||||
is_default_route_attempt=True,
|
||||
)
|
||||
else:
|
||||
if self.router.route_table.default_route:
|
||||
if not is_default_route_attempt:
|
||||
@@ -1460,6 +1482,8 @@ class Router(NetworkNode):
|
||||
frame.ethernet.src_mac_addr = network_interface.mac_address
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
network_interface.send_frame(frame)
|
||||
else:
|
||||
self.sys_log.error(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
@@ -1540,6 +1564,13 @@ class Router(NetworkNode):
|
||||
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
|
||||
- src_ip_address (str, optional): IP address octet written in base 10
|
||||
- dst_ip_address (str, optional): IP address octet written in base 10
|
||||
- routes (list[dict]): List of route dicts with values:
|
||||
- address (str): The destination address of the route.
|
||||
- subnet_mask (str): The subnet mask of the route.
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
- metric (int): The metric of the route. Optional.
|
||||
- default_route:
|
||||
- next_hop_ip_address (str): The next hop IP for the route.
|
||||
|
||||
Example config:
|
||||
```
|
||||
@@ -1550,6 +1581,10 @@ class Router(NetworkNode):
|
||||
1: {
|
||||
'ip_address' : '192.168.1.1',
|
||||
'subnet_mask' : '255.255.255.0',
|
||||
},
|
||||
2: {
|
||||
'ip_address' : '192.168.0.1',
|
||||
'subnet_mask' : '255.255.255.252',
|
||||
}
|
||||
},
|
||||
'acl' : {
|
||||
@@ -1557,6 +1592,10 @@ class Router(NetworkNode):
|
||||
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
|
||||
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
|
||||
},
|
||||
'routes' : [
|
||||
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
|
||||
],
|
||||
'default_route': {'next_hop_ip_address': '192.168.0.2'}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1600,4 +1639,8 @@ class Router(NetworkNode):
|
||||
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
|
||||
metric=float(route.get("metric", 0)),
|
||||
)
|
||||
if "default_route" in cfg:
|
||||
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
|
||||
if next_hop_ip_address:
|
||||
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
|
||||
return router
|
||||
|
||||
@@ -100,13 +100,8 @@ class Switch(NetworkNode):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not self.network_interface:
|
||||
self.network_interface = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
|
||||
for port_num, port in self.network_interface.items():
|
||||
port._connected_node = self
|
||||
port.port_num = port_num
|
||||
port.parent = self
|
||||
port.port_num = port_num
|
||||
for i in range(1, self.num_ports + 1):
|
||||
self.connect_nic(SwitchPort())
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
|
||||
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader
|
||||
from primaite.simulator.network.utils import convert_bytes_to_megabits
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -141,3 +141,37 @@ class Frame(BaseModel):
|
||||
def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
|
||||
"""The daa transfer size of the Frame in Mbits."""
|
||||
return convert_bytes_to_megabits(self.size)
|
||||
|
||||
@property
|
||||
def is_broadcast(self) -> bool:
|
||||
"""
|
||||
Determines if the Frame is a broadcast frame.
|
||||
|
||||
A Frame is considered a broadcast frame if the destination MAC address is set to the broadcast address
|
||||
"ff:ff:ff:ff:ff:ff".
|
||||
|
||||
:return: True if the destination MAC address is a broadcast address, otherwise False.
|
||||
"""
|
||||
return self.ethernet.dst_mac_addr.lower() == "ff:ff:ff:ff:ff:ff"
|
||||
|
||||
@property
|
||||
def is_arp(self) -> bool:
|
||||
"""
|
||||
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
|
||||
|
||||
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
|
||||
|
||||
:return: True if the Frame is an ARP packet, otherwise False.
|
||||
"""
|
||||
return self.udp.dst_port == Port.ARP
|
||||
|
||||
@property
|
||||
def is_icmp(self) -> bool:
|
||||
"""
|
||||
Determines if the Frame is an ICMP (Internet Control Message Protocol) packet.
|
||||
|
||||
This check is performed by verifying if the 'icmp' attribute of the Frame instance is present (not None).
|
||||
|
||||
:return: True if the Frame is an ICMP packet (i.e., has an ICMP header), otherwise False.
|
||||
"""
|
||||
return self.icmp is not None
|
||||
|
||||
@@ -11,6 +11,9 @@ class Port(Enum):
|
||||
.. _List of Ports:
|
||||
"""
|
||||
|
||||
UNUSED = -1
|
||||
"An unused port stub."
|
||||
|
||||
NONE = 0
|
||||
"Place holder for a non-port."
|
||||
WOL = 9
|
||||
|
||||
@@ -63,3 +63,8 @@ class Simulation(SimComponent):
|
||||
"""Apply a timestep to the simulation."""
|
||||
super().apply_timestep(timestep)
|
||||
self.network.apply_timestep(timestep)
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
self.network.pre_timestep(timestep)
|
||||
|
||||
@@ -80,7 +80,10 @@ class Application(IOSoftware):
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
self.num_executions = 0 # reset number of executions
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
self.num_executions = 0
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -31,6 +31,7 @@ class DatabaseClient(Application):
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
"""Keep track of the latest query response. Used to determine rewards."""
|
||||
_server_connection_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
@@ -51,10 +52,9 @@ class DatabaseClient(Application):
|
||||
def execute(self) -> bool:
|
||||
"""Execution definition for db client: perform a select query."""
|
||||
self.num_executions += 1 # trying to connect counts as an execution
|
||||
if self.connections:
|
||||
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
|
||||
else:
|
||||
can_connect = self.check_connection(connection_id=str(uuid4()))
|
||||
if not self._server_connection_id:
|
||||
self.connect()
|
||||
can_connect = self.check_connection(connection_id=self._server_connection_id)
|
||||
self._last_connection_successful = can_connect
|
||||
return can_connect
|
||||
|
||||
@@ -80,17 +80,21 @@ class DatabaseClient(Application):
|
||||
self.server_password = server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self, connection_id: Optional[str] = None) -> bool:
|
||||
def connect(self) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not connection_id:
|
||||
connection_id = str(uuid4())
|
||||
if not self._server_connection_id:
|
||||
self._server_connection_id = str(uuid4())
|
||||
|
||||
self.connected = self._connect(
|
||||
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
connection_id=self._server_connection_id,
|
||||
)
|
||||
if not self.connected:
|
||||
self._server_connection_id = None
|
||||
return self.connected
|
||||
|
||||
def check_connection(self, connection_id: str) -> bool:
|
||||
@@ -125,7 +129,7 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self.connections.get(connection_id):
|
||||
if self._server_connection_id:
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
)
|
||||
@@ -149,31 +153,28 @@ class DatabaseClient(Application):
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
)
|
||||
|
||||
def disconnect(self, connection_id: Optional[str] = None) -> bool:
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not len(self.connections):
|
||||
if not self._server_connection_id:
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
return False
|
||||
|
||||
# if no connection provided, disconnect the first connection
|
||||
if not connection_id:
|
||||
connection_id = list(self.connections.keys())[0]
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
payload={"type": "disconnect", "connection_id": self._server_connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.remove_connection(connection_id=connection_id)
|
||||
self.remove_connection(connection_id=self._server_connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
f"{self.name}: DatabaseClient disconnected {self._server_connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
self.connected = False
|
||||
|
||||
@@ -224,18 +225,20 @@ class DatabaseClient(Application):
|
||||
# reset last query response
|
||||
self.last_query_response = None
|
||||
|
||||
if connection_id is None:
|
||||
if self.connections:
|
||||
connection_id = list(self.connections.keys())[-1]
|
||||
# TODO: if the most recent connection dies, it should be automatically cleared.
|
||||
else:
|
||||
connection_id = str(uuid4())
|
||||
connection_id: str
|
||||
|
||||
if not self.connections.get(connection_id):
|
||||
if not self.connect(connection_id=connection_id):
|
||||
return False
|
||||
if not connection_id:
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
self.connect()
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
msg = "Cannot run sql query, could not establish connection with the server."
|
||||
self.parent.sys_log(msg)
|
||||
return False
|
||||
|
||||
# Initialise the tracker of this ID to False
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RansomwareAttackStage(IntEnum):
|
||||
"""
|
||||
Enumeration representing different attack stages of the ransomware script.
|
||||
|
||||
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
|
||||
in the simulation.
|
||||
Each stage represents a specific phase in the attack process.
|
||||
"""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Indicates that the attack has not started yet."
|
||||
DOWNLOAD = 1
|
||||
"Installing the Encryption Script - Testing"
|
||||
INSTALL = 2
|
||||
"The stage where logon procedures are simulated."
|
||||
ACTIVATE = 3
|
||||
"Operating Status Changes"
|
||||
PROPAGATE = 4
|
||||
"Represents the stage of performing a horizontal port scan on the target."
|
||||
COMMAND_AND_CONTROL = 5
|
||||
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
|
||||
PAYLOAD = 6
|
||||
"Stage of actively attacking the target."
|
||||
SUCCEEDED = 7
|
||||
"Indicates the attack has been successfully completed."
|
||||
FAILED = 8
|
||||
"Signifies that the attack has failed."
|
||||
|
||||
|
||||
class RansomwareScript(Application):
|
||||
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
|
||||
|
||||
:ivar payload: The attack stage query payload. (Default Corrupt)
|
||||
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:ivar repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of node which hosts the database."""
|
||||
server_password: Optional[str] = None
|
||||
"""Password required to access the database."""
|
||||
payload: Optional[str] = "ENCRYPT"
|
||||
"Payload String for the payload stage"
|
||||
target_scan_p_of_success: float = 0.9
|
||||
"Probability of the target scan succeeding: Default 0.9"
|
||||
c2_beacon_p_of_success: float = 0.9
|
||||
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
|
||||
ransomware_encrypt_p_of_success: float = 0.9
|
||||
"Probability of the ransomware attack succeeding: Default 0.9"
|
||||
repeat: bool = False
|
||||
"If true, the Denial of Service bot will keep performing the attack."
|
||||
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
|
||||
"The ransomware attack stage. See RansomwareAttackStage Class"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "RansomwareScript"
|
||||
kwargs["port"] = Port.NONE
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
@property
|
||||
def _host_db_client(self) -> DatabaseClient:
|
||||
"""Return the database client that is installed on the same machine as the Ransomware Script."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
|
||||
)
|
||||
return rm
|
||||
|
||||
def _activate(self):
|
||||
"""
|
||||
Simulate the install process as the initial stage of the attack.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' attack state.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.INSTALL:
|
||||
self.sys_log.info(f"{self.name}: Activated!")
|
||||
self.attack_stage = RansomwareAttackStage.ACTIVATE
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the bot, triggering the application loop.
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self) -> bool:
|
||||
"""Calls the parent classes execute method before starting the application loop."""
|
||||
super().run()
|
||||
return True
|
||||
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop of the script, handling the attack process.
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
self._local_download()
|
||||
self._install()
|
||||
self._activate()
|
||||
self._perform_target_scan()
|
||||
self._setup_beacon()
|
||||
self._perform_ransomware_encrypt()
|
||||
|
||||
if self.repeat and self.attack_stage in (
|
||||
RansomwareAttackStage.SUCCEEDED,
|
||||
RansomwareAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def configure(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
server_password: Optional[str] = None,
|
||||
payload: Optional[str] = None,
|
||||
target_scan_p_of_success: Optional[float] = None,
|
||||
c2_beacon_p_of_success: Optional[float] = None,
|
||||
ransomware_encrypt_p_of_success: Optional[float] = None,
|
||||
repeat: bool = True,
|
||||
):
|
||||
"""
|
||||
Configure the Ransomware Script to communicate with a DatabaseService.
|
||||
|
||||
:param server_ip_address: The IP address of the Node the DatabaseService is on.
|
||||
:param server_password: The password on the DatabaseService.
|
||||
:param payload: The attack stage query (Encrypt / Delete)
|
||||
:param target_scan_p_of_success: The probability of success for the target scan stage.
|
||||
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
|
||||
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
|
||||
:param repeat: Whether to repeat attacking once finished.
|
||||
"""
|
||||
if server_ip_address:
|
||||
self.server_ip_address = server_ip_address
|
||||
if server_password:
|
||||
self.server_password = server_password
|
||||
if payload:
|
||||
self.payload = payload
|
||||
if target_scan_p_of_success:
|
||||
self.target_scan_p_of_success = target_scan_p_of_success
|
||||
if c2_beacon_p_of_success:
|
||||
self.c2_beacon_p_of_success = c2_beacon_p_of_success
|
||||
if ransomware_encrypt_p_of_success:
|
||||
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
|
||||
if repeat:
|
||||
self.repeat = repeat
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
|
||||
f"{repeat=}."
|
||||
)
|
||||
|
||||
def _install(self):
|
||||
"""
|
||||
Simulate the install stage in the kill-chain.
|
||||
|
||||
Advances the attack stage to 'ACTIVATE' if successful.
|
||||
|
||||
From this attack stage onwards.
|
||||
the ransomware application is now visible from this point onwardin the observation space.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
|
||||
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
|
||||
downloads_folder = self.file_system.get_folder(folder_name="downloads")
|
||||
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
|
||||
ransomware_file.num_access += 1
|
||||
self.attack_stage = RansomwareAttackStage.INSTALL
|
||||
|
||||
def _setup_beacon(self):
|
||||
"""
|
||||
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
|
||||
|
||||
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
|
||||
|
||||
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
|
||||
by default the success rate is 0.5
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
|
||||
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
|
||||
if simulate_trial(self.c2_beacon_p_of_success):
|
||||
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
|
||||
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
|
||||
if c2c_setup:
|
||||
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
|
||||
|
||||
def _perform_target_scan(self):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PROPAGATE` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
|
||||
if simulate_trial(self.target_scan_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
|
||||
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
|
||||
if port_is_open:
|
||||
self.attack_stage = RansomwareAttackStage.PROPAGATE
|
||||
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Ransomware application is unable to perform it's actions.")
|
||||
self.run()
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
def _perform_ransomware_encrypt(self):
|
||||
"""
|
||||
Execute the Ransomware Encrypt payload on the target.
|
||||
|
||||
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
|
||||
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
|
||||
"""
|
||||
if self._host_db_client is None:
|
||||
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
return
|
||||
|
||||
self._host_db_client.server_ip_address = self.server_ip_address
|
||||
self._host_db_client.server_password = self.server_password
|
||||
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
|
||||
if simulate_trial(self.ransomware_encrypt_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
self.attack_stage = RansomwareAttackStage.SUCCEEDED
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.error("Attack Attempted to launch too quickly")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download(self):
|
||||
"""Downloads itself via the onto the local file_system."""
|
||||
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
|
||||
if self._local_download_verify():
|
||||
self.attack_stage = RansomwareAttackStage.DOWNLOAD
|
||||
else:
|
||||
self.sys_log.info("Malware failed to create a installation location")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.info("Malware failed to download")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download_verify(self) -> bool:
|
||||
"""Verifies a download location - Creates one if needed."""
|
||||
for folder in self.file_system.folders:
|
||||
if self.file_system.folders[folder].name == "downloads":
|
||||
self.file_system.num_file_creations += 1
|
||||
return True
|
||||
|
||||
self.file_system.create_folder("downloads")
|
||||
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
|
||||
return True
|
||||
@@ -88,6 +88,10 @@ class SysLog:
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self.hostname}_sys.log"
|
||||
|
||||
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
|
||||
if to_terminal or SIM_OUTPUT.write_sys_log_to_terminal:
|
||||
print(f"{self.hostname}: ({level}) {msg}")
|
||||
|
||||
def debug(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
Logs a message with the DEBUG level.
|
||||
@@ -97,8 +101,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.debug(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "DEBUG", to_terminal)
|
||||
|
||||
def info(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -109,8 +112,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.info(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "INFO", to_terminal)
|
||||
|
||||
def warning(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -121,8 +123,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.warning(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "WARNING", to_terminal)
|
||||
|
||||
def error(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -133,8 +134,7 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.error(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "ERROR", to_terminal)
|
||||
|
||||
def critical(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
@@ -145,5 +145,4 @@ class SysLog:
|
||||
"""
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.critical(msg)
|
||||
if to_terminal:
|
||||
print(msg)
|
||||
self._write_to_terminal(msg, "CRITICAL", to_terminal)
|
||||
|
||||
@@ -141,8 +141,7 @@ class DatabaseService(Service):
|
||||
"""Returns the database file."""
|
||||
return self.file_system.get_file(folder_name="database", file_name="database.db")
|
||||
|
||||
@property
|
||||
def folder(self) -> Folder:
|
||||
def _return_database_folder(self) -> Folder:
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
@@ -187,7 +186,10 @@ class DatabaseService(Service):
|
||||
}
|
||||
|
||||
def _process_sql(
|
||||
self, query: Literal["SELECT", "DELETE", "INSERT"], query_id: str, connection_id: Optional[str] = None
|
||||
self,
|
||||
query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"],
|
||||
query_id: str,
|
||||
connection_id: Optional[str] = None,
|
||||
) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
@@ -196,6 +198,7 @@ class DatabaseService(Service):
|
||||
- SELECT : returns the data
|
||||
- DELETE : deletes the data
|
||||
- INSERT : inserts the data
|
||||
- ENCRYPT : corrupts the data
|
||||
|
||||
:param query: The SQL query to be executed.
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
@@ -207,7 +210,15 @@ class DatabaseService(Service):
|
||||
return {"status_code": 404, "type": "sql", "data": False}
|
||||
|
||||
if query == "SELECT":
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
@@ -226,6 +237,20 @@ class DatabaseService(Service):
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
elif query == "ENCRYPT":
|
||||
self.file_system.num_file_creations += 1
|
||||
self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
self.db_file.num_access += 1
|
||||
database_folder = self._return_database_folder()
|
||||
database_folder.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
self.file_system.num_file_deletions += 1
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
elif query == "INSERT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {
|
||||
|
||||
@@ -87,13 +87,9 @@ class NTPClient(Service):
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
if not isinstance(payload, NTPPacket):
|
||||
_LOGGER.debug(f"{payload} is not a NTPPacket")
|
||||
_LOGGER.debug(f"{self.name}: Failed to parse NTP update")
|
||||
return False
|
||||
if payload.ntp_reply.ntp_datetime:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: \
|
||||
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
|
||||
)
|
||||
self.time = payload.ntp_reply.ntp_datetime
|
||||
return True
|
||||
|
||||
@@ -124,5 +120,3 @@ class NTPClient(Service):
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
# request time from server
|
||||
self.request_time()
|
||||
else:
|
||||
self.sys_log.debug(f"{self.name} ntp client not running")
|
||||
|
||||
@@ -224,6 +224,10 @@ class Software(SimComponent):
|
||||
if self.health_state_actual == SoftwareHealthState.FIXING:
|
||||
self._update_fix_status()
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
|
||||
class IOSoftware(Software):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user