Merged PR 224: Enhanced probabilistic red agent.

## Summary
This is @<Jake Walker> 's and @<Christopher McCarthy> 's PR.

- The red agent now performs attacks in stages.
- There is now a random probability of each stage being successfully run.
- There are start settings for the red agent which set the start time, frequency and variance for both.

## Test process
Additional unit tests have been added to test new functionality.

## Checklist
- [ ] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #1859
This commit is contained in:
Marek Wolan
2023-11-28 10:47:39 +00:00
committed by Christopher McCarthy
47 changed files with 1464 additions and 495 deletions

View File

@@ -31,7 +31,8 @@ SessionManager.
- `DatabaseClient` and `DatabaseService` created to allow emulation of database actions
- Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup
- Red Agent Services:
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database)
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database). The attack runs in stages with a random, configurable probability of succeeding.
- `DataManipulationAgent` runs the Data Manipulator Bot according to a configured start step, frequency and variance.
- DNS Services: `DNSClient` and `DNSServer`
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`

View File

@@ -18,19 +18,28 @@ The bot is intended to simulate a malicious actor carrying out attacks like:
- Modifying data
on a database server by abusing an application's trusted database connectivity.
The bot performs attacks in the following stages to simulate the real pattern of an attack:
- Logon - *The bot gains credentials and accesses the node.*
- Port Scan - *The bot finds accessible database servers on the network.*
- Attacking - *The bot delivers the payload to the discovered database servers.*
Each of these stages has a random, configurable probability of succeeding (by default 10%). The bot can also be configured to repeat the attack once complete.
Usage
-----
- Create an instance and call ``configure`` to set:
- Target database server IP
- Database password (if needed)
- SQL statement payload
- Target database server IP
- Database password (if needed)
- SQL statement payload
- Probabilities for succeeding each of the above attack stages
- Call ``run`` to connect and execute the statement.
The bot handles connecting, executing the statement, and disconnecting.
In a simulation, the bot can be controlled by using ``DataManipulationAgent`` which calls ``run`` on the bot at configured timesteps.
Example
-------
@@ -51,13 +60,81 @@ Example
This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table.
Example with ``DataManipulationAgent``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If not using the data manipulation bot manually, it needs to be used with a data manipulation agent. Below is an example section of configuration file for setting up a simulation with data manipulation bot and agent.
.. code-block:: yaml
game_config:
# ...
agents:
- ref: data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
applications:
- application_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 25
frequency: 20
variance: 5
# ...
simulation:
network:
nodes:
- ref: client_1
type: computer
# ... additional configuration here
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
Implementation
--------------
The bot extends ``DatabaseClient`` and leverages its connectivity.
- Uses the Application base class for lifecycle management.
- Credentials and target IP set via ``configure``.
- Credentials, target IP and other options set via ``configure``.
- ``run`` handles connecting, executing statement, and disconnecting.
- SQL payload executed via ``query`` method.
- Results in malicious SQL being executed on remote database server.

View File

@@ -1,5 +1,5 @@
training_config:
rl_framework: RLLIB_single_agent
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
@@ -36,31 +36,26 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
applications:
- application_ref: client_2_web_browser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -69,38 +64,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DELETE"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -110,9 +87,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -562,17 +540,25 @@ simulation:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
@@ -609,7 +595,7 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
@@ -630,6 +616,10 @@ simulation:
services:
- ref: database_service
type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- ref: database_ftp_client
type: FTPClient
- ref: backup_server
type: server
@@ -640,7 +630,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server
@@ -661,9 +651,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
@@ -677,10 +673,14 @@ simulation:
applications:
- ref: client_2_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1

View File

@@ -15,7 +15,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
@@ -82,7 +81,7 @@ class NodeServiceAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, service_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -98,7 +97,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
self.verb: str = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
@@ -106,7 +105,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
self.verb: str = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
@@ -114,7 +113,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
self.verb: str = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
@@ -122,7 +121,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
self.verb: str = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
@@ -130,7 +129,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
self.verb: str = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
@@ -138,7 +137,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
self.verb: str = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
@@ -146,7 +145,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
self.verb: str = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
@@ -154,7 +153,38 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
self.verb: str = "enable"
class NodeApplicationAbstractAction(AbstractAction):
"""
Base class for application actions.
Any action which applies to an application and uses node_id and application_id as its only two parameters can
inherit from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, application_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id)
if node_uuid is None or application_uuid is None:
return ["do_nothing"]
return ["network", "node", node_uuid, "application", application_uuid, self.verb]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
"""Action which executes an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb: str = "execute"
class NodeFolderAbstractAction(AbstractAction):
@@ -169,7 +199,7 @@ class NodeFolderAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -223,7 +253,7 @@ class NodeFileAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -240,7 +270,7 @@ class NodeFileScanAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "scan"
self.verb: str = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
@@ -248,7 +278,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "checkhash"
self.verb: str = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
@@ -256,7 +286,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "delete"
self.verb: str = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
@@ -264,7 +294,7 @@ class NodeFileRepairAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "repair"
self.verb: str = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
@@ -272,7 +302,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "restore"
self.verb: str = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction):
@@ -280,7 +310,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "corrupt"
self.verb: str = "corrupt"
class NodeAbstractAction(AbstractAction):
@@ -294,7 +324,7 @@ class NodeAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -307,7 +337,7 @@ class NodeOSScanAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "scan"
self.verb: str = "scan"
class NodeShutdownAction(NodeAbstractAction):
@@ -315,7 +345,7 @@ class NodeShutdownAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "shutdown"
self.verb: str = "shutdown"
class NodeStartupAction(NodeAbstractAction):
@@ -323,7 +353,7 @@ class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "startup"
self.verb: str = "startup"
class NodeResetAction(NodeAbstractAction):
@@ -331,7 +361,7 @@ class NodeResetAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "reset"
self.verb: str = "reset"
class NetworkACLAddRuleAction(AbstractAction):
@@ -489,7 +519,7 @@ class NetworkNICAbstractAction(AbstractAction):
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, nic_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -512,7 +542,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "enable"
self.verb: str = "enable"
class NetworkNICDisableAction(NetworkNICAbstractAction):
@@ -520,7 +550,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
self.verb: str = "disable"
class ActionManager:
@@ -536,6 +566,7 @@ class ActionManager:
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
"NODE_FILE_SCAN": NodeFileScanAction,
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
"NODE_FILE_DELETE": NodeFileDeleteAction,
@@ -562,9 +593,11 @@ class ActionManager:
game: "PrimaiteGame", # reference to game for information lookup
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
application_uuids: List[List[str]], # allows mapping index to application
max_folders_per_node: int = 2, # allows calculating shape
max_files_per_folder: int = 2, # allows calculating shape
max_services_per_node: int = 2, # allows calculating shape
max_applications_per_node: int = 10, # allows calculating shape
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
@@ -600,8 +633,8 @@ class ActionManager:
:type act_map: Optional[Dict[int, Dict]]
"""
self.game: "PrimaiteGame" = game
self.sim: Simulation = self.game.simulation
self.node_uuids: List[str] = node_uuids
self.application_uuids: List[List[str]] = application_uuids
self.protocols: List[str] = protocols
self.ports: List[str] = ports
@@ -611,7 +644,7 @@ class ActionManager:
else:
self.ip_address_list = []
for node_uuid in self.node_uuids:
node_obj = self.sim.network.nodes[node_uuid]
node_obj = self.game.simulation.network.nodes[node_uuid]
nics = node_obj.nics
for nic_uuid, nic_obj in nics.items():
self.ip_address_list.append(nic_obj.ip_address)
@@ -622,6 +655,7 @@ class ActionManager:
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_applications": max_applications_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols),
@@ -734,7 +768,7 @@ class ActionManager:
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
node = self.game.simulation.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
@@ -752,7 +786,7 @@ class ActionManager:
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
node = self.game.simulation.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
if len(folder_uuids) <= folder_idx:
return None
@@ -771,10 +805,22 @@ class ActionManager:
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
node = self.game.simulation.network.nodes[node_uuid]
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
"""Get the application UUID corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param application_idx: The index of the service on the node.
:type application_idx: int
:return: The UUID of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
return self.application_uuids[node_idx][application_idx]
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
"""Get the internet protocol corresponding to the given index.
@@ -819,7 +865,7 @@ class ActionManager:
:rtype: str
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node_obj = self.sim.network.nodes[node_uuid]
node_obj = self.game.simulation.network.nodes[node_uuid]
nics = list(node_obj.nics.keys())
if len(nics) <= nic_idx:
return None

View File

@@ -0,0 +1,48 @@
import random
from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
class DataManipulationAgent(AbstractScriptedAgent):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
data_manipulation_bots: List["DataManipulationBot"] = []
next_execution_timestep: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.game.step_counter
if current_timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}

View File

@@ -1,13 +1,64 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
if TYPE_CHECKING:
pass
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AgentStartSettings":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions"
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
"""Construct agent settings from a config dictionary.
:param config: A dict of options for the agent settings.
:type config: Dict
:return: The agent settings.
:rtype: AgentSettings
"""
if config is None:
return cls()
return cls(**config)
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
@@ -18,6 +69,7 @@ class AbstractAgent(ABC):
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings] = None,
) -> None:
"""
Initialize an agent.
@@ -35,10 +87,7 @@ class AbstractAgent(ABC):
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
self.agent_settings = agent_settings or AgentSettings()
def update_observation(self, state: Dict) -> ObsType:
"""

View File

@@ -162,7 +162,7 @@ class ServiceObservation(AbstractObservation):
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]])
class LinkObservation(AbstractObservation):

View File

@@ -213,7 +213,7 @@ class WebServer404Penalty(AbstractReward):
_LOGGER.warn(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref].uuid
service_uuid = game.ref_map_services[service_ref]
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"

View File

@@ -1,5 +1,4 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from copy import deepcopy
from ipaddress import IPv4Address
from typing import Dict, List
@@ -7,10 +6,11 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
@@ -18,14 +18,14 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -57,9 +57,6 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -75,16 +72,16 @@ class PrimaiteGame:
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
self.ref_map_nodes: Dict[str, Node] = {}
self.ref_map_nodes: Dict[str, str] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
self.ref_map_services: Dict[str, str] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
self.ref_map_applications: Dict[str, str] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
self.ref_map_links: Dict[str, str] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
def step(self):
@@ -157,7 +154,7 @@ class PrimaiteGame:
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
self.simulation.reset_component_for_episode(episode=self.episode_counter)
def close(self) -> None:
"""Close the game, this will close the simulation."""
@@ -187,10 +184,6 @@ class PrimaiteGame:
sim = game.simulation
net = sim.network
game.ref_map_nodes: Dict[str, Node] = {}
game.ref_map_services: Dict[str, Service] = {}
game.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
for node_cfg in nodes_cfg:
@@ -203,6 +196,7 @@ class PrimaiteGame:
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg["dns_server"],
operating_state=NodeOperatingState.ON,
)
elif n_type == "server":
new_node = Server(
@@ -211,16 +205,26 @@ class PrimaiteGame:
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg.get("dns_server"),
operating_state=NodeOperatingState.ON,
)
elif n_type == "switch":
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
new_node = Switch(
hostname=node_cfg["hostname"],
num_ports=node_cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
elif n_type == "router":
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
new_node = Router(
hostname=node_cfg["hostname"],
num_ports=node_cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
# new_node.enable_port(port_num)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
@@ -239,6 +243,7 @@ class PrimaiteGame:
print("invalid node type")
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
service_ref = service_cfg["ref"]
service_type = service_cfg["type"]
service_types_mapping = {
@@ -247,13 +252,14 @@ class PrimaiteGame:
"DatabaseClient": DatabaseClient,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"DataManipulationBot": DataManipulationBot,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
}
if service_type in service_types_mapping:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
game.ref_map_services[service_ref] = new_service
game.ref_map_services[service_ref] = new_service.uuid
else:
print(f"service type not found {service_type}")
# service-dependent options
@@ -268,30 +274,49 @@ class PrimaiteGame:
if "domain_mapping" in opt:
for domain, ip in opt["domain_mapping"].items():
new_service.dns_register(domain, ip)
if service_type == "DatabaseService":
if "options" in service_cfg:
opt = service_cfg["options"]
if "backup_server_ip" in opt:
new_service.configure_backup(backup_server=IPv4Address(opt["backup_server_ip"]))
new_service.start()
if "applications" in node_cfg:
for application_cfg in node_cfg["applications"]:
new_application = None
application_ref = application_cfg["ref"]
application_type = application_cfg["type"]
application_types_mapping = {
"WebBrowser": WebBrowser,
"DataManipulationBot": DataManipulationBot,
}
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application
game.ref_map_applications[application_ref] = new_application.uuid
else:
print(f"application type not found {application_type}")
if application_type == "DataManipulationBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
payload=opt.get("payload"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "WebBrowser":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.target_url = opt.get("target_url")
if "nics" in node_cfg:
for nic_num, nic_cfg in node_cfg["nics"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
net.add_node(new_node)
new_node.power_on()
game.ref_map_nodes[
node_ref
] = (
new_node.uuid
) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
game.ref_map_nodes[node_ref] = new_node.uuid
# 2. create links between nodes
for link_cfg in links_cfg:
@@ -323,11 +348,25 @@ class PrimaiteGame:
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
action_space_cfg["options"]["application_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
if "applications" in action_node_option:
node_application_uuids = []
for application_option in action_node_option["applications"]:
# TODO: fix inconsistency with node uuids and application uuids. The node object get added to
# node_uuid, whereas here the application gets added by uuid.
application_uuid = game.ref_map_applications[application_option["application_ref"]]
node_application_uuids.append(application_uuid)
action_space_cfg["options"]["application_uuids"].append(node_application_uuids)
else:
action_space_cfg["options"]["application_uuids"].append([])
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
# However, it's not possible to specify the node uuids directly in the config, as they are generated
@@ -345,6 +384,8 @@ class PrimaiteGame:
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
# TODO: implement non-random agents and fix this parsing
@@ -353,6 +394,7 @@ class PrimaiteGame:
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
@@ -365,16 +407,17 @@ class PrimaiteGame:
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
else:
print("agent type not found")
game._simulation_initial_state = deepcopy(game.simulation) # noqa
game.simulation.set_original_state()
return game

View File

@@ -0,0 +1,16 @@
from random import random
def simulate_trial(p_of_success: float) -> bool:
"""
Simulates the outcome of a single trial in a Bernoulli process.
This function returns True with a probability 'p_of_success', simulating a success outcome in a single
trial of a Bernoulli process. When this function is executed multiple times, the set of outcomes follows
a binomial distribution. This is useful in scenarios where one needs to model or simulate events that
have two possible outcomes (success or failure) with a fixed probability of success.
:param p_of_success: The probability of success in a single trial, ranging from 0 to 1.
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
"""
return random() < p_of_success

View File

@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n"
]
}
],
"source": [
"from primaite.session.session import PrimaiteSession\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.simulator.system.services.database.database_service import DatabaseService\n",
"\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"installing DNSServer on node domain_controller\n",
"installing DatabaseClient on node web_server\n",
"installing WebServer on node web_server\n",
"installing DatabaseService on node database_server\n",
"installing FTPClient on node database_server\n",
"installing FTPServer on node backup_server\n",
"installing DNSClient on node client_1\n",
"installing DNSClient on node client_2\n"
]
}
],
"source": [
"\n",
"with open(example_config_path(),'r') as cfgfile:\n",
" cfg = yaml.safe_load(cfgfile)\n",
"game = PrimaiteGame.from_config(cfg)\n",
"net = game.simulation.network\n",
"database_server = net.get_node_by_hostname('database_server')\n",
"web_server = net.get_node_by_hostname('web_server')\n",
"client_1 = net.get_node_by_hostname('client_1')\n",
"\n",
"db_service = database_server.software_manager.software[\"DatabaseService\"]\n",
"db_client = web_server.software_manager.software[\"DatabaseClient\"]\n",
"# db_client.run()\n",
"db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n",
"db_manipulation_bot.port_scan_p_of_success=1.0\n",
"db_manipulation_bot.data_manipulation_p_of_success=1.0\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"db_client.run()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_service.backup_database()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"db_manipulation_bot.run()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_service.restore_backup()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_manipulation_bot.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client_1.ping(database_server.ethernet_port[1].ip_address)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import validate_call, BaseModel"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class A(BaseModel):\n",
" x:int\n",
"\n",
" @validate_call\n",
" def increase_x(self, by:int) -> None:\n",
" self.x += 1"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"my_a = A(x=3)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"ename": "ValidationError",
"evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell://wsl%2Bubuntu/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb#X23sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n",
"File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float"
]
}
],
"source": [
"my_a.increase_x(3.2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -37,7 +37,7 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {}
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:

View File

@@ -88,16 +88,16 @@ class PrimaiteSession:
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
io_manager = SessionIO(SessionIOSettings(**io_settings))
game = PrimaiteGame.from_config(cfg)
sess = cls(game=game)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config={"game": game})

View File

@@ -153,6 +153,8 @@ class SimComponent(BaseModel):
uuid: str
"""The component UUID."""
_original_state: Dict = {}
def __init__(self, **kwargs):
if not kwargs.get("uuid"):
kwargs["uuid"] = str(uuid4())
@@ -160,6 +162,16 @@ class SimComponent(BaseModel):
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
# @abstractmethod
def set_original_state(self):
"""Sets the original state."""
pass
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for key, value in self._original_state.items():
self.__setattr__(key, value)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager for this component.
@@ -227,14 +239,6 @@ class SimComponent(BaseModel):
"""
pass
def reset_component_for_episode(self, episode: int):
"""
Reset this component to its original state for a new episode.
Override this method with anything that needs to happen within the component for it to be reset.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.

View File

@@ -42,6 +42,19 @@ class Account(SimComponent):
"Account Type, currently this can be service account (used by apps) or user account."
enabled: bool = True
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"num_logons",
"num_logoffs",
"num_group_changes",
"username",
"password",
"account_type",
"enabled",
}
self._original_state = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -73,6 +73,18 @@ class File(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
@property
def path(self) -> str:
"""

View File

@@ -35,6 +35,36 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def set_original_state(self):
"""Sets the original state."""
for folder in self.folders.values():
folder.set_original_state()
super().set_original_state()
# Capture a list of all 'original' file uuids
self._original_state["original_folder_uuids"] = list(self.folders.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
# Move any 'original' folder that have been deleted back to folders
original_folder_uuids = self._original_state.pop("original_folder_uuids")
for uuid in original_folder_uuids:
if uuid in self.deleted_folders:
self.folders[uuid] = self.deleted_folders.pop(uuid)
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
# Now clear all non-original folders created by agent
current_folder_uuids = list(self.folders.keys())
for uuid in current_folder_uuids:
if uuid not in original_folder_uuids:
self.folders.pop(uuid)
# Now reset all remaining folders
for folder in self.folders.values():
folder.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent):
deleted: bool = False
"If true, the FileSystemItem was deleted."
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -51,6 +51,44 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def set_original_state(self):
"""Sets the original state."""
for file in self.files.values():
file.set_original_state()
super().set_original_state()
vals_to_include = {
"scan_duration",
"scan_countdown",
"red_scan_duration",
"red_scan_countdown",
"restore_duration",
"restore_countdown",
}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_file_uuids"] = list(self.files.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
# Move any 'original' file that have been deleted back to files
original_file_uuids = self._original_state.pop("original_file_uuids")
for uuid in original_file_uuids:
if uuid in self.deleted_files:
self.files[uuid] = self.deleted_files.pop(uuid)
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
# Now clear all non-original files created by agent
current_file_uuids = list(self.files.keys())
for uuid in current_file_uuids:
if uuid not in original_file_uuids:
self.files.pop(uuid)
# Now reset all remaining files
for file in self.files.values():
file.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(

View File

@@ -43,6 +43,20 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
def set_original_state(self):
"""Sets the original state."""
for node in self.nodes.values():
node.set_original_state()
for link in self.links.values():
link.set_original_state()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.reset_component_for_episode(episode)
for link in self.links.values():
link.reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()

View File

@@ -121,6 +121,21 @@ class NIC(SimComponent):
_LOGGER.error(msg)
raise ValueError(msg)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -308,6 +323,14 @@ class SwitchPort(SimComponent):
kwargs["mac_address"] = generate_mac_address()
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -454,6 +477,14 @@ class Link(SimComponent):
self.endpoint_b.connect_link(self)
self.endpoint_up()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"bandwidth", "current_load"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -536,15 +567,6 @@ class Link(SimComponent):
return True
return False
def reset_component_for_episode(self, episode: int):
"""
Link reset function.
Reset:
- returns the link current_load to 0.
"""
self.current_load = 0
def __str__(self) -> str:
return f"{self.endpoint_a}<-->{self.endpoint_b}"
@@ -584,6 +606,10 @@ class ARPCache:
)
print(table)
def clear(self):
"""Clears the arp cache."""
self.arp.clear()
def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False):
"""
Add an ARP entry to the cache.
@@ -756,6 +782,10 @@ class ICMP:
self.arp: ARPCache = arp_cache
self.request_replies = {}
def clear(self):
"""Clears the ICMP request replies tracker."""
self.request_replies.clear()
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
"""
Process an ICMP packet, including handling echo requests and replies.
@@ -959,6 +989,62 @@ class Node(SimComponent):
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
self._install_system_software()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
for software in self.software_manager.software.values():
software.set_original_state()
self.file_system.set_original_state()
for nic in self.nics.values():
nic.set_original_state()
vals_to_include = {
"hostname",
"default_gateway",
"operating_state",
"revealed_to_red",
"start_up_duration",
"start_up_countdown",
"shut_down_duration",
"shut_down_countdown",
"is_resetting",
"node_scan_duration",
"node_scan_countdown",
"red_scan_countdown",
}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
# Reset ARP Cache
self.arp.clear()
# Reset ICMP
self.icmp.clear()
# Reset Session Manager
self.session_manager.clear()
# Reset software
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
# Reset File System
self.file_system.reset_component_for_episode(episode)
# Reset all Nics
for nic in self.nics.values():
nic.reset_component_for_episode(episode)
#
if episode and self.sys_log:
self.sys_log.current_episode = episode
self.sys_log.setup_logger()
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
@@ -1442,99 +1528,3 @@ class Node(SimComponent):
if isinstance(item, Service):
return item.uuid in self.services
return None
class Switch(Node):
"""A class representing a Layer 2 network switch."""
num_ports: int = 24
"The number of ports on the switch."
switch_ports: Dict[int, SwitchPort] = {}
"The SwitchPorts on the switch."
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port._connected_node = self
port.parent = self
port.port_num = port_num
def show(self):
"""Prints a table of the SwitchPorts on the Switch."""
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
for port_num, port in self.switch_ports.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return {
"uuid": self.uuid,
"num_ports": self.num_ports, # redundant?
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
}
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
mac_table_port = self.mac_address_table.get(mac_address)
if not mac_table_port:
self.mac_address_table[mac_address] = switch_port
self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")
else:
if mac_table_port != switch_port:
self.mac_address_table.pop(mac_address)
self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}")
self._add_mac_table_entry(mac_address, switch_port)
def forward_frame(self, frame: Frame, incoming_port: SwitchPort):
"""
Forward a frame to the appropriate port based on the destination MAC address.
:param frame: The Frame to be forwarded.
:param incoming_port: The port number from which the frame was received.
"""
src_mac = frame.ethernet.src_mac_addr
dst_mac = frame.ethernet.dst_mac_addr
self._add_mac_table_entry(src_mac, incoming_port)
outgoing_port = self.mac_address_table.get(dst_mac)
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.switch_ports.values():
if port != incoming_port:
port.send_frame(frame)
def disconnect_link_from_port(self, link: Link, port_number: int):
"""
Disconnect a given link from the specified port number on the switch.
:param link: The Link object to be disconnected.
:param port_number: The port number on the switch from where the link should be disconnected.
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
"""
port = self.switch_ports.get(port_number)
if port is None:
msg = f"Invalid port number {port_number} on the switch"
_LOGGER.error(msg)
raise NetworkError(msg)
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)
port.disconnect_link()

View File

@@ -52,6 +52,11 @@ class ACLRule(SimComponent):
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -93,6 +98,18 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.implicit_rule.set_original_state()
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -337,6 +354,11 @@ class RouteEntry(SimComponent):
kwargs[key] = IPv4Address(kwargs[key])
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
self._original_values = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteEntry.
@@ -368,6 +390,18 @@ class RouteTable(SimComponent):
routes: List[RouteEntry] = []
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.routes.clear()
self.routes = self._original_state["routes_orig"]
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteTable.
@@ -638,6 +672,26 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.acl.set_original_state()
self.route_table.set_original_state()
vals_to_include = {"num_ports"}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.arp.clear()
self.acl.reset_component_for_episode(episode)
self.route_table.reset_component_for_episode(episode)
for i, nic in self.ethernet_ports.items():
nic.reset_component_for_episode(episode)
self.enable_port(i)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))
@@ -730,6 +784,7 @@ class Router(Node):
dst_ip_address=dst_ip_address,
dst_port=dst_port,
)
if not permitted:
at_port = self._get_port_of_nic(from_nic)
self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}")
@@ -763,6 +818,7 @@ class Router(Node):
nic.ip_address = ip_address
nic.subnet_mask = subnet_mask
self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}")
self.set_original_state()
def enable_port(self, port: int):
"""

View File

@@ -140,7 +140,12 @@ def arcd_uc2_network() -> Network:
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
db_manipulation_bot.configure(
server_ip_address=IPv4Address("192.168.1.14"),
payload="DELETE",
port_scan_p_of_success=1.0,
data_manipulation_p_of_success=1.0,
)
# Client 2
client_2 = Computer(
@@ -152,6 +157,8 @@ def arcd_uc2_network() -> Network:
operating_state=NodeOperatingState.ON,
)
client_2.power_on()
web_browser = client_2.software_manager.software["WebBrowser"]
web_browser.target_url = "http://arcd.com/users/"
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller

View File

@@ -1,4 +1,4 @@
from enum import Enum
from enum import Enum, IntEnum
from primaite.simulator.network.protocols.packet import DataPacket
@@ -25,7 +25,7 @@ class HttpRequestMethod(Enum):
"""Apply partial modifications to a resource."""
class HttpStatusCode(Enum):
class HttpStatusCode(IntEnum):
"""List of available HTTP Statuses."""
OK = 200

View File

@@ -9,7 +9,7 @@ class Simulation(SimComponent):
"""Top-level simulation object which holds a reference to all other parts of the simulation."""
network: Network
domain: DomainController
# domain: DomainController
def __init__(self, **kwargs):
"""Initialise the Simulation."""
@@ -21,6 +21,14 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
self.network.set_original_state()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.network.reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
# pass through network requests to the network objects

View File

@@ -41,6 +41,12 @@ class Application(IOSoftware):
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -90,6 +96,10 @@ class Application(IOSoftware):
self.sys_log.info(f"Running Application {self.name}")
self.operating_state = ApplicationOperatingState.RUNNING
def _application_loop(self):
"""The main application loop."""
pass
def close(self) -> None:
"""Close the Application."""
if self.operating_state == ApplicationOperatingState.RUNNING:
@@ -98,23 +108,11 @@ class Application(IOSoftware):
def install(self) -> None:
"""Install Application."""
if self._can_perform_action():
return
super().install()
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Installing Application {self.name}")
self.operating_state = ApplicationOperatingState.INSTALLING
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.

View File

@@ -31,6 +31,13 @@ class DatabaseClient(Application):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"server_ip_address", "server_password", "connected"}
self._original_state.update(self.model_dump(include=vals_to_include))
def describe_state(self) -> Dict:
"""
@@ -78,11 +85,11 @@ class DatabaseClient(Application):
"""
if is_reattempt:
if self.connected:
self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} authorised")
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
self.server_ip_address = server_ip_address
return self.connected
else:
self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} declined")
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
return False
payload = {"type": "connect_request", "password": password}
software_manager: SoftwareManager = self.software_manager
@@ -135,8 +142,8 @@ class DatabaseClient(Application):
def run(self) -> None:
"""Run the DatabaseClient."""
super().run()
self.operating_state = ApplicationOperatingState.RUNNING
self.connect()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
def query(self, sql: str, is_reattempt: bool = False) -> bool:
"""

View File

@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from urllib.parse import urlparse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
HttpRequestPacket,
@@ -21,6 +22,8 @@ class WebBrowser(Application):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
target_url: Optional[str] = None
domain_name_ip_address: Optional[IPv4Address] = None
"The IP address of the domain name for the webpage."
@@ -35,8 +38,23 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.set_original_state()
self.run()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa
)
return rm
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of the WebBrowser.
@@ -47,16 +65,9 @@ class WebBrowser(Application):
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
"""Reset the original state of the SimComponent."""
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
self.domain_name_ip_address = None
self.latest_response = None
def get_webpage(self, url: str) -> bool:
def get_webpage(self) -> bool:
"""
Retrieve the webpage.
@@ -65,6 +76,7 @@ class WebBrowser(Application):
:param: url: The address of the web page the browser requests
:type: url: str
"""
url = self.target_url
if not self._can_perform_action():
return False
@@ -79,7 +91,6 @@ class WebBrowser(Application):
# get the IP address of the domain name via DNS
dns_client: DNSClient = self.software_manager.software["DNSClient"]
domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname)
# if domain does not exist, the request fails

View File

@@ -34,9 +34,12 @@ class PacketCapture:
"The IP address associated with the PCAP logs."
self.switch_port_number = switch_port_number
"The SwitchPort number."
self._setup_logger()
def _setup_logger(self):
self.current_episode: int = 1
self.setup_logger()
def setup_logger(self):
"""Set up the logger configuration."""
log_path = self._get_log_path()
@@ -75,7 +78,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = SIM_OUTPUT.path / self.hostname
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -93,6 +93,11 @@ class SessionManager:
"""
pass
def clear(self):
"""Clears the sessions."""
self.sessions_by_key.clear()
self.sessions_by_uuid.clear()
@staticmethod
def _get_session_key(
frame: Frame, inbound_frame: bool = True

View File

@@ -31,9 +31,10 @@ class SysLog:
:param hostname: The hostname associated with the system logs being recorded.
"""
self.hostname = hostname
self._setup_logger()
self.current_episode: int = 1
self.setup_logger()
def _setup_logger(self):
def setup_logger(self):
"""
Configures the logger for this SysLog instance.
@@ -80,7 +81,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT.path / self.hostname
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -24,6 +24,12 @@ class Process(Software):
operating_state: ProcessOperatingState
"The current operating state of the Process."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -38,6 +38,23 @@ class DatabaseService(Service):
self._db_file: File
self._create_db_file()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {
"password",
"connections",
"backup_server",
"latest_backup_directory",
"latest_backup_file_name",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.connections.clear()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.

View File

@@ -29,6 +29,17 @@ class DNSClient(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"dns_server"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_cache.clear()
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.

View File

@@ -28,6 +28,22 @@ class DNSServer(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"dns_table"}
self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
print("dns reset")
print("DNSServer original state", self._original_state)
self.dns_table.clear()
for key, value in self._original_state["dns_table_orig"].items():
self.dns_table[key] = value
super().reset_component_for_episode(episode)
self.show()
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -68,15 +84,6 @@ class DNSServer(Service):
self.dns_table[domain_name] = domain_ip_address
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(
self,
payload: Any,

View File

@@ -1,27 +1,67 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
class DataManipulationBot(DatabaseClient):
class DataManipulationAttackStage(IntEnum):
"""
Red Agent Data Integration Service.
Enumeration representing different stages of a data manipulation attack.
The Service represents a bot that causes files/folders in the File System to
become corrupted.
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the
simulation. Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
LOGON = 1
"The stage where logon procedures are simulated."
PORT_SCAN = 2
"Represents the stage of performing a horizontal port scan on the target."
ATTACKING = 3
"Stage of actively attacking the target."
SUCCEEDED = 4
"Indicates the attack has been successfully completed."
FAILED = 5
"Signifies that the attack has failed."
class DataManipulationBot(DatabaseClient):
"""A bot that simulates a script which performs a SQL injection attack."""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED
repeat: bool = False
"Whether to repeat attacking once finished."
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
return rm
def configure(
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
self,
server_ip_address: IPv4Address,
server_password: Optional[str] = None,
payload: Optional[str] = None,
port_scan_p_of_success: float = 0.1,
data_manipulation_p_of_success: float = 0.1,
repeat: bool = False,
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
@@ -29,26 +69,111 @@ class DataManipulationBot(DatabaseClient):
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The data manipulation query payload.
:param port_scan_p_of_success: The probability of success for the port scan stage.
:param data_manipulation_p_of_success: The probability of success for the data manipulation stage.
:param repeat: Whether to repeat attacking once finished.
"""
self.server_ip_address = server_ip_address
self.payload = payload
self.server_password = server_password
self.port_scan_p_of_success = port_scan_p_of_success
self.data_manipulation_p_of_success = data_manipulation_p_of_success
self.repeat = repeat
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
f"{repeat=}."
)
def run(self):
"""Run the DataManipulationBot."""
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Attempting to start the {self.name}")
super().run()
else:
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_ip_address and payload.")
def _logon(self):
"""
Simulate the logon process as the initial stage of the attack.
def attack(self):
"""Run the data manipulation attack."""
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
Advances the attack stage to `LOGON` if successful.
"""
if self.attack_stage == DataManipulationAttackStage.NOT_STARTED:
# Bypass this stage as we're not dealing with logon for now
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DataManipulationAttackStage.LOGON
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PORT_SCAN` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == DataManipulationAttackStage.LOGON:
# perform a port scan to identify that the SQL port is open on the server
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the port scan
port_is_open = True # Temporary; later we can implement NMAP port scan.
if port_is_open:
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DataManipulationAttackStage.PORT_SCAN
def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1):
"""
Execute the data manipulation attack on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing data manipulation, by default 0.1.
"""
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
# perform the actual data manipulation attack
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing data manipulation")
# perform the attack
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
attack_successful = True
if attack_successful:
self.sys_log.info(f"{self.name}: Data manipulation successful")
self.attack_stage = DataManipulationAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Data manipulation failed")
self.attack_stage = DataManipulationAttackStage.FAILED
def run(self):
"""
Run the Data Manipulation Bot.
Calls the parent classes execute method before starting the application loop.
"""
super().run()
self._application_loop()
def _application_loop(self):
"""
The main application loop of the bot, handling the attack process.
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if self.operating_state != ApplicationOperatingState.RUNNING:
return
if self.server_ip_address and self.payload and self.operating_state:
self.sys_log.info(f"{self.name}: Running")
self._logon()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success)
if self.repeat and self.attack_stage in (
DataManipulationAttackStage.SUCCEEDED,
DataManipulationAttackStage.FAILED,
):
self.attack_stage = DataManipulationAttackStage.NOT_STARTED
else:
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, triggering the application loop.
:param timestep: The timestep value to update the bot's state.
"""
self._application_loop()

View File

@@ -80,6 +80,12 @@ class Service(IOSoftware):
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
@@ -107,15 +113,6 @@ class Service(IOSoftware):
state["health_state_visible"] = self.health_state_visible
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def stop(self) -> None:
"""Stop the service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:

View File

@@ -17,7 +17,21 @@ from primaite.simulator.system.services.service import Service
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
last_response_status_code: Optional[HttpStatusCode] = None
_last_response_status_code: Optional[HttpStatusCode] = None
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self._last_response_status_code = None
super().reset_component_for_episode(episode)
@property
def last_response_status_code(self) -> HttpStatusCode:
"""The latest http response code."""
return self._last_response_status_code
@last_response_status_code.setter
def last_response_status_code(self, val: Any):
self._last_response_status_code = val
def describe_state(self) -> Dict:
"""
@@ -30,8 +44,9 @@ class WebServer(Service):
"""
state = super().describe_state()
state["last_response_status_code"] = (
self.last_response_status_code.value if self.last_response_status_code else None
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
print(state)
return state
def __init__(self, **kwargs):

View File

@@ -90,6 +90,19 @@ class Software(SimComponent):
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"name",
"health_state_actual",
"health_state_visible",
"criticality",
"patching_count",
"scanning_count",
"revealed_to_red",
}
self._original_state = self.model_dump(include=vals_to_include)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -132,16 +145,6 @@ class Software(SimComponent):
)
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the software component for a new episode.
This method should ensure the software is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a
"reset" should be implemented in subclasses.
"""
pass
def set_health_state(self, health_state: SoftwareHealthState) -> None:
"""
Assign a new health state to this software.
@@ -204,6 +207,12 @@ class IOSoftware(Software):
port: Port
"The port to which the software is connected."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -27,14 +27,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
@@ -48,10 +40,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -60,38 +53,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -101,9 +76,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -652,9 +628,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -52,10 +52,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -64,50 +65,32 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -656,9 +639,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -58,10 +58,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -70,38 +71,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -111,9 +94,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender1
team: BLUE
@@ -1093,9 +1077,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -56,10 +56,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -68,38 +69,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -109,9 +92,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -660,9 +644,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -52,10 +52,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -64,38 +65,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -105,9 +88,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -656,9 +640,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -23,7 +23,6 @@ def test_data_manipulation(uc2_network):
# Now we run the DataManipulationBot
db_manipulation_bot.run()
db_manipulation_bot.attack()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT")

View File

@@ -4,7 +4,6 @@ from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.http import HttpStatusCode
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.service import ServiceOperatingState
def test_web_page_home_page(uc2_network):
@@ -12,9 +11,10 @@ def test_web_page_home_page(uc2_network):
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_client.target_url = "http://arcd.com/"
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage("http://arcd.com/") is True
assert web_client.get_webpage() is True
# latest reponse should have status code 200
assert web_client.latest_response is not None
@@ -28,7 +28,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network):
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage("http://arcd.com/users/") is True
assert web_client.get_webpage() is True
# latest reponse should have status code 200
assert web_client.latest_response is not None
@@ -42,11 +42,12 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network):
web_client.run()
web_server: Server = uc2_network.get_node_by_hostname("web_server")
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address
web_client.target_url = f"http://{web_server_ip}/users/"
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage(f"http://{web_server_ip}/users/") is True
assert web_client.get_webpage() is True
# latest response should have status code 200
assert web_client.latest_response is not None

View File

@@ -1,20 +1,73 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.services.red_services.data_manipulation_bot import (
DataManipulationAttackStage,
DataManipulationBot,
)
def test_creation():
@pytest.fixture(scope="function")
def dm_client() -> Node:
network = arcd_uc2_network()
return network.get_node_by_hostname("client_1")
client_1: Node = network.get_node_by_hostname("client_1")
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
@pytest.fixture
def dm_bot(dm_client) -> DataManipulationBot:
return dm_client.software_manager.software["DataManipulationBot"]
def test_create_dm_bot(dm_client):
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"]
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.payload == "DELETE"
def test_dm_bot_logon(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.NOT_STARTED
dm_bot._logon()
assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON
def test_dm_bot_perform_port_scan_no_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.LOGON
dm_bot._perform_port_scan(p_of_success=0.0)
assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON
def test_dm_bot_perform_port_scan_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.LOGON
dm_bot._perform_port_scan(p_of_success=1.0)
assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN
def test_dm_bot_perform_data_manipulation_no_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN
dm_bot._perform_data_manipulation(p_of_success=0.0)
assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN
def test_dm_bot_perform_data_manipulation_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN
dm_bot.operating_state = ApplicationOperatingState.RUNNING
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
assert dm_bot.connected