#1859 - Started giving the red agent some 'intelligence' and a sense of a state. Changed Application.run to .execute.

This commit is contained in:
Chris McCarthy
2023-11-13 15:55:14 +00:00
parent 815bdfe603
commit 23fd9c3839
10 changed files with 151 additions and 34 deletions

View File

@@ -19,10 +19,10 @@ class GATERLAgent(AbstractGATEAgent):
def __init__(
self,
agent_name: str | None,
action_space: ActionManager | None,
observation_space: ObservationSpace | None,
reward_function: RewardFunction | None,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
) -> None:
super().__init__(agent_name, action_space, observation_space, reward_function)
self.most_recent_action: ActType

View File

@@ -109,6 +109,8 @@ class RandomAgent(AbstractScriptedAgent):
"""
return self.action_space.get_action(self.action_space.space.sample())
class DataManipulationAgent(AbstractScriptedAgent):
pass
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""

View File

@@ -0,0 +1,16 @@
from random import random
def simulate_trial(p_of_success: float):
"""
Simulates the outcome of a single trial in a Bernoulli process.
This function returns True with a probability 'p_of_success', simulating a success outcome in a single
trial of a Bernoulli process. When this function is executed multiple times, the set of outcomes follows
a binomial distribution. This is useful in scenarios where one needs to model or simulate events that
have two possible outcomes (success or failure) with a fixed probability of success.
:param p_of_success: The probability of success in a single trial, ranging from 0 to 1.
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
"""
return random() < p_of_success

View File

@@ -60,7 +60,7 @@ class PrimaiteGATEClient(GATEClient):
return self.parent_session.training_options.rl_algorithm
@property
def seed(self) -> int | None:
def seed(self) -> Optional[int]:
"""The seed to use for the environment's random number generator."""
return self.parent_session.training_options.seed
@@ -115,7 +115,7 @@ class PrimaiteGATEClient(GATEClient):
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict]:
"""Reset the environment.
This method is called when the environment is initialized and at the end of each episode.

View File

@@ -65,6 +65,10 @@ class Application(IOSoftware):
self.sys_log.info(f"Running Application {self.name}")
self.operating_state = ApplicationOperatingState.RUNNING
def _application_loop(self):
"""THe main application loop."""
pass
def close(self) -> None:
"""Close the Application."""
if self.operating_state == ApplicationOperatingState.RUNNING:

View File

@@ -128,11 +128,11 @@ class DatabaseClient(Application):
)
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
def run(self) -> None:
def execute(self) -> None:
"""Run the DatabaseClient."""
super().run()
self.operating_state = ApplicationOperatingState.RUNNING
self.connect()
super().execute()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
def query(self, sql: str) -> bool:
"""

View File

@@ -30,7 +30,7 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.run()
self.execute()
def describe_state(self) -> Dict:
"""

View File

@@ -1,27 +1,46 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite.game.science import simulate_trial
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
class DataManipulationAttackStage(IntEnum):
"""
Enumeration representing different stages of a data manipulation attack.
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the
simulation. Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
LOGON = 1
"The stage where logon procedures are simulated."
PORT_SCAN = 2
"Represents the stage of performing a horizontal port scan on the target."
ATTACKING = 3
"Stage of actively attacking the target."
COMPLETE = 4
"Indicates the attack has been successfully completed."
FAILED = 5
"Signifies that the attack has failed."
class DataManipulationBot(DatabaseClient):
"""
Red Agent Data Integration Service.
The Service represents a bot that causes files/folders in the File System to
become corrupted.
"""
"""A bot that simulates a script which performs a SQL injection attack."""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def configure(
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
@@ -37,15 +56,92 @@ class DataManipulationBot(DatabaseClient):
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
)
def run(self):
"""Run the DataManipulationBot."""
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Attempting to start the {self.name}")
super().run()
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
def _logon(self):
"""
Simulate the logon process as the initial stage of the attack.
Advances the attack stage to `LOGON` if successful.
"""
if self.attack_stage == DataManipulationAttackStage.NOT_STARTED:
# Bypass this stage as we're not dealing with logon for now
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DataManipulationAttackStage.LOGON
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PORT_SCAN` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == DataManipulationAttackStage.LOGON:
# perform a port scan to identify that the SQL port is open on the server
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the port scan
port_is_open = True # Temporary; later we can implement NMAP port scan.
if port_is_open:
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DataManipulationAttackStage.PORT_SCAN
def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1):
"""
Execute the data manipulation attack on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing data manipulation, by default 0.1.
"""
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
# perform the actual data manipulation attack
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the attack
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
attack_successful = True
if attack_successful:
self.sys_log.info(f"{self.name}: Performing port scan")
self.attack_stage = DataManipulationAttackStage.COMPLETE
else:
self.sys_log.info(f"{self.name}: Performing port scan")
self.attack_stage = DataManipulationAttackStage.FAILED
def execute(self):
"""
Execute the Data Manipulation Bot
Calls the parent classes execute method before starting the application loop.
"""
super().execute()
self._application_loop()
def _application_loop(self):
"""
The main application loop of the bot, handling the attack process.
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if self.operating_state != ApplicationOperatingState.RUNNING:
return
if self.server_ip_address and self.payload and self.operating_state:
self.sys_log.info(f"{self.name}: Running")
self._logon()
self._perform_port_scan()
self._perform_data_manipulation()
else:
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_ip_address and payload.")
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, triggering the application loop.
:param timestep: The timestep value to update the bot's state.
"""
self._application_loop()

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from typing import Any, Dict, Union
from unittest.mock import patch
import nodeenv
import pytest
from primaite import getLogger

View File

@@ -10,7 +10,7 @@ def test_web_page_home_page(uc2_network):
"""Test to see if the browser is able to open the main page of the web server."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_client.execute()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage("http://arcd.com/") is True
@@ -24,7 +24,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network):
"""Test to see if the client can handle requests with domain names"""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_client.execute()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage("http://arcd.com/users/") is True
@@ -38,7 +38,7 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network):
"""Test to see if the client can handle requests that use ip_address."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_client.execute()
web_server: Server = uc2_network.get_node_by_hostname("web_server")
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address