Merge remote-tracking branch 'origin/dev' into feature/2350-confirm-action-observation-space-conforms-to-CAOS-0.7

This commit is contained in:
Czar Echavez
2024-03-12 09:11:30 +00:00
35 changed files with 781 additions and 232 deletions

View File

@@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Made requests fail to reach their target if the node is off
- Added responses to requests
- Made environment reset completely recreate the game object.
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
- Changed the data manipulation scenario to include a second green agent on client 1.

View File

@@ -6,7 +6,7 @@ SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
AUTOSUMMARY="source\_autosummary"
AUTOSUMMARY="source/_autosummary"
# Remove command is different depending on OS
ifdef OS

View File

@@ -3,7 +3,7 @@
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Request System
==============
**************
``SimComponent`` objects in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``RequestType``.
@@ -12,26 +12,37 @@ Just like other aspects of SimComponent, the request types are not managed centr
- API
When requesting an action within the simulation, these two arguments must be provided:
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '<node-name>', 'service', '<service-name>', 'restart']`.
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as ``['network', 'node', '<node-name>', 'service', '<service-name>', 'restart']``.
2. ``context`` - optional extra information that can be used to decide how to process the request. This is formatted as a dictionary. For example, if the request requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient.
When a request is resolved, it returns a success status, and optional additional data about the request.
``status`` can be one of:
* ``success``: the request was executed
* ``failure``: the request could not be executed
* ``unreachable``: the target for the request was not found
* ``pending``: the request was initiated, but has not finished during this step
``data`` can be a dictionary with any arbitrary JSON-like data to describe the outcome of the request.
- ``request`` detail
The request is a list of strings which help specify who should handle the request. The strings in the request list help RequestManagers traverse the 'ownership tree' of SimComponent. The example given above would be handled in the following way:
1. ``Simulation`` receives `['network', 'node', '<node-name>', 'service', '<service-name>', 'restart']`.
1. ``Simulation`` receives ``['network', 'node', 'computer_1', 'service', 'DNSService', 'restart']``.
The first element of the request is ``network``, therefore it passes the request down to its network.
2. ``Network`` receives `['node', '<node-name>', 'service', '<service-name>', 'restart']`.
2. ``Network`` receives ``['node', 'computer_1', 'service', 'DNSService', 'restart']``.
The first element of the request is ``node``, therefore the network looks at the node name and passes the request down to the node with that name.
3. ``Node`` receives `['service', '<service-name>', 'restart']`.
3. ``computer_1`` receives ``['service', 'DNSService', 'restart']``.
The first element of the request is ``service``, therefore the node looks at the service name and passes the rest of the request to the service with that name.
4. ``Service`` receives ``['restart']``.
4. ``DNSService`` receives ``['restart']``.
Since ``restart`` is a defined request type in the service's own RequestManager, the service performs a restart.
- ``context`` detail
The context is not used by any of the currently implemented components or requests.
Technical Detail
----------------
================
This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.RequestType`, and :py:class:`primaite.simulator.core.RequestManager`.
@@ -93,3 +104,19 @@ An example of how this works is in the :py:class:`primaite.simulator.network.har
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
This process is repeated until the request word corresponds to a callable function rather than another ``RequestManager`` .
Request Validation
------------------
There are times when a request should be rejected. For instance, if an agent attempts to run an application on a node that is currently off. For this purpose, requests are filtered by an object called a validator. :py:class:`primaite.simulator.core.RequestPermissionValidator` is a basic class whose ``__call__()`` method returns ``True`` if the request should be permitted or ``False`` if it cannot be permitted. For example, the Node class has a validator called :py:class:`primaite.simulator.network.hardware.base.Node._NodeIsOnValidator<_NodeIsOnValidator>` which allows requests only when the operating status of the node is ``ON``.
Requests that are specified without a validator automatically get assigned an ``AllowAllValidator`` which allows requests no matter what.
Request Response
----------------
The :py:class:`primaite.interface.request.RequestResponse<RequestResponse>` is a data transfer object that carries response data between the simulator and the game layer. The ``status`` field reports on the success or failure, and the ``data`` field is for any additional data. The most common way that this class is initiated is by its ``from_bool`` method. This way, given a True or False, a successful or failed request response is generated, respectively (with an empty data field).
For instance, the ``execute`` action on a :py:class:`primaite.simulator.system.applications.web_browser.WebBrowser<WebBrowser>` calls the ``get_webpage()`` method of the ``WebBrowser``. ``get_webpage()`` returns a True if the webpage was successfully retrieved, and False if unsuccessful for any reason, such as being blocked by an ACL, or if the database server is unresponsive. The boolean returned from ``get_webpage()`` is used to create the request response.
Just as the requests themselves were passed from owner to component, the request response is bubbled back up from component to owner until it arrives at the game layer.

View File

@@ -5,9 +5,9 @@
Simulation State
================
``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the childrens' own ``describe_state`` methods.
``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the children's own ``describe_state`` methods.
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objetcs must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objects must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
This code snippet demonstrates how the state information is defined within the ``SimComponent`` class:

View File

@@ -492,9 +492,9 @@ class NetworkACLAddRuleAction(AbstractAction):
"add_rule",
permission_str,
protocol,
src_ip,
str(src_ip),
src_port,
dst_ip,
str(dst_ip),
dst_port,
position,
]

View File

@@ -170,9 +170,13 @@ class PrimaiteGame:
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
response = self.simulation.apply_request(request)
agent_actions[agent.agent_name] = {
"action": action_choice,
"parameters": options,
"response": response.model_dump(),
}
return agent_actions
def advance_timestep(self) -> None:

View File

View File

@@ -0,0 +1,44 @@
from typing import Dict, ForwardRef, Literal
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
RequestResponse = ForwardRef("RequestResponse")
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
class RequestResponse(BaseModel):
"""Schema for generic request responses."""
model_config = ConfigDict(extra="forbid", strict=True)
"""Cannot have extra fields in the response. Anything custom goes into the data field."""
status: Literal["pending", "success", "failure", "unreachable"] = "pending"
"""
What is the current status of the request:
- pending - the request has not been received yet, or it has been received but it's still being processed.
- success - the request has been received and executed successfully.
- failure - the request has been received and attempted, but execution failed.
- unreachable - the request could not reach it's intended target, either because it doesn't exist or the target
is off.
"""
data: Dict = {}
"""Catch-all place to provide any additional data that was generated as a response to the request."""
# TODO: currently, status and data have default values, because I don't want to interrupt existing functionality too
# much. However, in the future we might consider making them mandatory.
@classmethod
@validate_call
def from_bool(cls, status_bool: StrictBool) -> RequestResponse:
"""
Construct a basic request response from a boolean.
True maps to a success status. False maps to a failure status.
:param status_bool: Whether to create a successful response
:type status_bool: bool
"""
if status_bool is True:
return cls(status="success", data={})
elif status_bool is False:
return cls(status="failure", data={})

View File

@@ -426,13 +426,13 @@
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info[0]\n",
" red_action = red_info['action']\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str"
" return red_str\n"
]
},
{
@@ -492,7 +492,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent."
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) increased from 0 to 1, but only right after the red attack, so we probably cannot see it now."
]
},
{
@@ -510,9 +510,9 @@
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
"print(f\"Blue reward:{reward}\" )"
]
},
@@ -535,7 +535,7 @@
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
"print(f\"Blue reward:{reward:.2f}\" )"
@@ -557,17 +557,17 @@
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(50) # Block client 1\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(51) # Block client 2\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"\n",
"for step in range(30):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
]
},
{
@@ -606,20 +606,35 @@
"metadata": {},
"outputs": [],
"source": [
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's unblock client 2\n",
" env.step(58) # remove ACL rule 6\n",
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" env.step(57) # remove ACL rule 5\n",
"else:\n",
" print(\"something went wrong, neither client has NMNEs\")"
"env.step(58) # Remove the ACL rule that blocks client 1\n",
"env.step(57) # Remove the ACL rule that blocks client 2\n",
"\n",
"tries = 0\n",
"while True:\n",
" tries += 1\n",
" obs, reward, terminated, truncated, info = env.step(0)\n",
"\n",
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's block it\n",
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
" break\n",
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 2 has NMNEs, so let's block it\n",
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
" break\n",
" if tries>100:\n",
" print(\"Error: NMNE never increased\")\n",
" break\n",
"\n",
"env.step(13) # Patch the database\n",
"..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks."
"Now, the reward will eventually increase to 0.9, even after red agent attempts subsequent attacks."
]
},
{
@@ -628,9 +643,10 @@
"metadata": {},
"outputs": [],
"source": [
"for step in range(30):\n",
"\n",
"for step in range(40):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
]
},
{
@@ -648,6 +664,13 @@
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -93,7 +93,7 @@ class PrimaiteIO:
{
"episode": episode,
"timestep": timestep,
"agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()},
"agent_actions": agent_actions,
}
]
)

View File

@@ -1,15 +1,18 @@
# flake8: noqa
"""Core of the PrimAITE Simulator."""
from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.interface.request import RequestResponse
_LOGGER = getLogger(__name__)
RequestFormat = List[Union[str, int, float]]
class RequestPermissionValidator(BaseModel):
"""
@@ -21,15 +24,15 @@ class RequestPermissionValidator(BaseModel):
"""
@abstractmethod
def __call__(self, request: List[str], context: Dict) -> bool:
"""Use the request and context paramters to decide whether the request should be permitted."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Use the request and context parameters to decide whether the request should be permitted."""
pass
class AllowAllValidator(RequestPermissionValidator):
"""Always allows the request."""
def __call__(self, request: List[str], context: Dict) -> bool:
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Always allow the request."""
return True
@@ -42,7 +45,7 @@ class RequestType(BaseModel):
the request can be performed or not.
"""
func: Callable[[List[str], Dict], None]
func: Callable[[RequestFormat, Dict], RequestResponse]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
@@ -71,7 +74,7 @@ class RequestManager(BaseModel):
request_types: Dict[str, RequestType] = {}
"""maps request name to an RequestType object."""
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
def __call__(self, request: RequestFormat, context: Dict) -> RequestResponse:
"""
Process an request request.
@@ -84,23 +87,23 @@ class RequestManager(BaseModel):
:raises RuntimeError: If the request parameter does not have a valid request name as the first item.
"""
request_key = request[0]
request_options = request[1:]
if request_key not in self.request_types:
msg = (
f"Request {request} could not be processed because {request_key} is not a valid request name",
"within this RequestManager",
)
_LOGGER.error(msg)
raise RuntimeError(msg)
_LOGGER.debug(msg)
return RequestResponse(status="unreachable", data={"reason": msg})
request_type = self.request_types[request_key]
request_options = request[1:]
if not request_type.validator(request_options, context):
_LOGGER.debug(f"Request {request} was denied due to insufficient permissions")
return
return RequestResponse(status="failure", data={"reason": "request validation failed"})
request_type.func(request_options, context)
return request_type.func(request_options, context)
def add_request(self, name: str, request_type: RequestType) -> None:
"""
@@ -202,7 +205,8 @@ class SimComponent(BaseModel):
}
return state
def apply_request(self, request: List[str], context: Dict = {}) -> None:
@validate_call
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
"""
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
@@ -222,7 +226,7 @@ class SimComponent(BaseModel):
"""
if self._request_manager is None:
return
self._request_manager(request, context)
return self._request_manager(request, context)
def apply_timestep(self, timestep: int) -> None:
"""

View File

@@ -80,6 +80,11 @@ class DomainController(SimComponent):
super().__init__(**kwargs)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# Action 'account' matches requests like:
# ['account', '<account-uuid>', *account_action]
@@ -87,6 +92,7 @@ class DomainController(SimComponent):
"account",
RequestType(
func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context),
# TODO: not sure what should get returned here, revisit
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
),
)

View File

@@ -114,16 +114,17 @@ class File(FileSystemItemABC):
state["num_access"] = self.num_access
return state
def scan(self) -> None:
def scan(self) -> bool:
"""Updates the visible statuses of the file."""
if self.deleted:
self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}")
return
return False
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
self.visible_health_status = self.health_status
return True
def reveal_to_red(self) -> None:
"""Reveals the folder/file to the red agent."""
@@ -132,7 +133,7 @@ class File(FileSystemItemABC):
return
self.revealed_to_red = True
def check_hash(self) -> None:
def check_hash(self) -> bool:
"""
Check if the file has been changed.
@@ -142,7 +143,7 @@ class File(FileSystemItemABC):
"""
if self.deleted:
self.sys_log.error(f"Unable to check hash of deleted file {self.folder_name}/{self.name}")
return
return False
current_hash = None
# if file is real, read the file contents
@@ -164,12 +165,13 @@ class File(FileSystemItemABC):
# if the previous hash and current hash do not match, mark file as corrupted
if self.previous_hash is not current_hash:
self.corrupt()
return True
def repair(self) -> None:
def repair(self) -> bool:
"""Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD."""
if self.deleted:
self.sys_log.error(f"Unable to repair deleted file {self.folder_name}/{self.name}")
return
return False
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
@@ -178,12 +180,13 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
return True
def corrupt(self) -> None:
def corrupt(self) -> bool:
"""Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT."""
if self.deleted:
self.sys_log.error(f"Unable to corrupt deleted file {self.folder_name}/{self.name}")
return
return False
# set file status to good if corrupt
if self.health_status == FileSystemItemHealthStatus.GOOD:
@@ -192,12 +195,13 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
return True
def restore(self) -> None:
def restore(self) -> bool:
"""Determines if the file needs to be repaired or unmarked as deleted."""
if self.deleted:
self.deleted = False
return
return True
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.GOOD
@@ -205,13 +209,15 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
return True
def delete(self):
def delete(self) -> bool:
"""Marks the file as deleted."""
if self.deleted:
self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}")
return
return False
self.num_access += 1 # file was accessed
self.deleted = True
self.sys_log.info(f"File deleted {self.folder_name}/{self.name}")
return True

View File

@@ -7,6 +7,7 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_type import FileType
@@ -39,18 +40,27 @@ class FileSystem(SimComponent):
self.create_folder("root")
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
self._delete_manager = RequestManager()
self._delete_manager.add_request(
name="file",
request_type=RequestType(
func=lambda request, context: self.delete_file(folder_name=request[0], file_name=request[1])
func=lambda request, context: RequestResponse.from_bool(
self.delete_file(folder_name=request[0], file_name=request[1])
)
),
)
self._delete_manager.add_request(
name="folder",
request_type=RequestType(func=lambda request, context: self.delete_folder(folder_name=request[0])),
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0]))
),
)
rm.add_request(
name="delete",
@@ -61,12 +71,16 @@ class FileSystem(SimComponent):
self._restore_manager.add_request(
name="file",
request_type=RequestType(
func=lambda request, context: self.restore_file(folder_name=request[0], file_name=request[1])
func=lambda request, context: RequestResponse.from_bool(
self.restore_file(folder_name=request[0], file_name=request[1])
)
),
)
self._restore_manager.add_request(
name="folder",
request_type=RequestType(func=lambda request, context: self.restore_folder(folder_name=request[0])),
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self.restore_folder(folder_name=request[0]))
),
)
rm.add_request(
name="restore",
@@ -142,7 +156,7 @@ class FileSystem(SimComponent):
)
return folder
def delete_folder(self, folder_name: str):
def delete_folder(self, folder_name: str) -> bool:
"""
Deletes a folder, removes it from the folders list and removes any child folders and files.
@@ -150,24 +164,26 @@ class FileSystem(SimComponent):
"""
if folder_name == "root":
self.sys_log.warning("Cannot delete the root folder.")
return
return False
folder = self.get_folder(folder_name)
if folder:
# set folder to deleted state
folder.delete()
# remove from folder list
self.folders.pop(folder.uuid)
# add to deleted list
folder.remove_all_files()
self.deleted_folders[folder.uuid] = folder
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
else:
if not folder:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
return False
def delete_folder_by_id(self, folder_uuid: str):
# set folder to deleted state
folder.delete()
# remove from folder list
self.folders.pop(folder.uuid)
# add to deleted list
folder.remove_all_files()
self.deleted_folders[folder.uuid] = folder
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
return True
def delete_folder_by_id(self, folder_uuid: str) -> None:
"""
Deletes a folder via its uuid.
@@ -303,7 +319,7 @@ class FileSystem(SimComponent):
return file
def delete_file(self, folder_name: str, file_name: str):
def delete_file(self, folder_name: str, file_name: str) -> bool:
"""
Delete a file by its name from a specific folder.
@@ -317,8 +333,10 @@ class FileSystem(SimComponent):
# increment file creation
self.num_file_deletions += 1
folder.remove_file(file)
return True
return False
def delete_file_by_id(self, folder_uuid: str, file_uuid: str):
def delete_file_by_id(self, folder_uuid: str, file_uuid: str) -> None:
"""
Deletes a file via its uuid.
@@ -335,7 +353,7 @@ class FileSystem(SimComponent):
else:
self.sys_log.error(f"Unable to delete file that does not exist. (id: {file_uuid})")
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str) -> None:
"""
Move a file from one folder to another.
@@ -423,7 +441,7 @@ class FileSystem(SimComponent):
# Agent actions
###############################################################
def scan(self, instant_scan: bool = False):
def scan(self, instant_scan: bool = False) -> None:
"""
Scan all the folders (and child files) in the file system.
@@ -432,7 +450,7 @@ class FileSystem(SimComponent):
for folder_id in self.folders:
self.folders[folder_id].scan(instant_scan=instant_scan)
def reveal_to_red(self, instant_scan: bool = False):
def reveal_to_red(self, instant_scan: bool = False) -> None:
"""
Reveals all the folders (and child files) in the file system to the red agent.
@@ -441,7 +459,7 @@ class FileSystem(SimComponent):
for folder_id in self.folders:
self.folders[folder_id].reveal_to_red(instant_scan=instant_scan)
def restore_folder(self, folder_name: str):
def restore_folder(self, folder_name: str) -> bool:
"""
Restore a folder.
@@ -454,13 +472,14 @@ class FileSystem(SimComponent):
if folder is None:
self.sys_log.error(f"Unable to restore folder {folder_name}. Folder is not in deleted folder list.")
return
return False
self.deleted_folders.pop(folder.uuid, None)
folder.restore()
self.folders[folder.uuid] = folder
return True
def restore_file(self, folder_name: str, file_name: str):
def restore_file(self, folder_name: str, file_name: str) -> bool:
"""
Restore a file.
@@ -473,12 +492,15 @@ class FileSystem(SimComponent):
:type: file_name: str
"""
folder = self.get_folder(folder_name=folder_name)
if not folder:
_LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.")
return False
if folder:
file = folder.get_file(file_name=file_name, include_deleted=True)
file = folder.get_file(file_name=file_name, include_deleted=True)
if file is None:
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
return
if not file:
msg = f"Unable to restore file {file_name}. File was not found."
self.sys_log.error(msg)
return False
folder.restore_file(file_name=file_name)
return folder.restore_file(file_name=file_name)

View File

@@ -6,6 +6,7 @@ from enum import Enum
from typing import Dict, Optional
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.system.core.sys_log import SysLog
@@ -100,14 +101,33 @@ class FileSystemItemABC(SimComponent):
return state
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan()))
rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash()))
rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair()))
rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore()))
rm.add_request(
name="scan", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
)
rm.add_request(
name="checkhash",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.check_hash())),
)
rm.add_request(
name="repair",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.repair())),
)
rm.add_request(
name="restore",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.restore())),
)
rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt()))
rm.add_request(
name="corrupt",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.corrupt())),
)
return rm
@@ -124,9 +144,9 @@ class FileSystemItemABC(SimComponent):
return convert_size(self.size)
@abstractmethod
def scan(self) -> None:
def scan(self) -> bool:
"""Scan the folder/file - updates the visible_health_status."""
pass
return False
@abstractmethod
def reveal_to_red(self) -> None:
@@ -134,7 +154,7 @@ class FileSystemItemABC(SimComponent):
pass
@abstractmethod
def check_hash(self) -> None:
def check_hash(self) -> bool:
"""
Checks the has of the file to detect any changes.
@@ -142,30 +162,30 @@ class FileSystemItemABC(SimComponent):
Return False if corruption is detected, otherwise True
"""
pass
return False
@abstractmethod
def repair(self) -> None:
def repair(self) -> bool:
"""
Repair the FileSystemItem.
True if successfully repaired. False otherwise.
"""
pass
return False
@abstractmethod
def corrupt(self) -> None:
def corrupt(self) -> bool:
"""
Corrupt the FileSystemItem.
True if successfully corrupted. False otherwise.
"""
pass
return False
@abstractmethod
def restore(self) -> None:
def restore(self) -> bool:
"""Restore the file/folder to the state before it got ruined."""
pass
return False
@abstractmethod
def delete(self) -> None:

View File

@@ -5,6 +5,7 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
@@ -50,10 +51,17 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
name="delete",
request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])),
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self.remove_file_by_name(file_name=request[0]))
),
)
self._file_request_manager = RequestManager()
rm.add_request(
@@ -250,6 +258,21 @@ class Folder(FileSystemItemABC):
file = self.get_file_by_id(file_uuid=file_uuid)
self.remove_file(file=file)
def remove_file_by_name(self, file_name: str) -> bool:
"""
Remove a file using its name.
:param file_name: filename
:type file_name: str
:return: Whether it was successfully removed.
:rtype: bool
"""
for f in self.files.values():
if f.name == file_name:
self.remove_file(f)
return True
return False
def remove_all_files(self):
"""Removes all the files in the folder."""
for file_id in self.files:
@@ -259,7 +282,7 @@ class Folder(FileSystemItemABC):
self.files = {}
def restore_file(self, file_name: str):
def restore_file(self, file_name: str) -> bool:
"""
Restores a file.
@@ -269,13 +292,14 @@ class Folder(FileSystemItemABC):
file = self.get_file(file_name=file_name, include_deleted=True)
if not file:
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
return
return False
file.restore()
self.files[file.uuid] = file
if file.deleted:
self.deleted_files.pop(file.uuid)
return True
def quarantine(self):
"""Quarantines the File System Folder."""
@@ -289,7 +313,7 @@ class Folder(FileSystemItemABC):
"""Returns true if the folder is being quarantined."""
pass
def scan(self, instant_scan: bool = False) -> None:
def scan(self, instant_scan: bool = False) -> bool:
"""
Update Folder visible status.
@@ -297,7 +321,7 @@ class Folder(FileSystemItemABC):
"""
if self.deleted:
self.sys_log.error(f"Unable to scan deleted folder {self.name}")
return
return False
if instant_scan:
for file_id in self.files:
@@ -305,7 +329,7 @@ class Folder(FileSystemItemABC):
file.scan()
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
return
return True
if self.scan_countdown <= 0:
# scan one file per timestep
@@ -314,6 +338,7 @@ class Folder(FileSystemItemABC):
else:
# scan already in progress
self.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})")
return True
def reveal_to_red(self, instant_scan: bool = False):
"""
@@ -340,7 +365,7 @@ class Folder(FileSystemItemABC):
# scan already in progress
self.sys_log.info(f"Red Agent Scan is already in progress {self.name} (id: {self.uuid})")
def check_hash(self) -> None:
def check_hash(self) -> bool:
"""
Runs a :func:`check_hash` on all files in the folder.
@@ -353,7 +378,7 @@ class Folder(FileSystemItemABC):
"""
if self.deleted:
self.sys_log.error(f"Unable to check hash of deleted folder {self.name}")
return
return False
# iterate through the files and run a check hash
no_corrupted_files = True
@@ -369,12 +394,13 @@ class Folder(FileSystemItemABC):
self.corrupt()
self.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})")
return True
def repair(self) -> None:
def repair(self) -> bool:
"""Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD."""
if self.deleted:
self.sys_log.error(f"Unable to repair deleted folder {self.name}")
return
return False
# iterate through the files in the folder
for file_id in self.files:
@@ -388,8 +414,9 @@ class Folder(FileSystemItemABC):
self.health_status = FileSystemItemHealthStatus.GOOD
self.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})")
return True
def restore(self) -> None:
def restore(self) -> bool:
"""
If a Folder is corrupted, run a repair on the folder and its child files.
@@ -405,12 +432,13 @@ class Folder(FileSystemItemABC):
else:
# scan already in progress
self.sys_log.info(f"Folder restoration already in progress {self.name} (id: {self.uuid})")
return True
def corrupt(self) -> None:
def corrupt(self) -> bool:
"""Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT."""
if self.deleted:
self.sys_log.error(f"Unable to corrupt deleted folder {self.name}")
return
return False
# iterate through the files in the folder
for file_id in self.files:
@@ -421,11 +449,13 @@ class Folder(FileSystemItemABC):
self.health_status = FileSystemItemHealthStatus.CORRUPT
self.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})")
return True
def delete(self):
def delete(self) -> bool:
"""Marks the file as deleted. Prevents agent actions from occuring."""
if self.deleted:
self.sys_log.error(f"Unable to delete an already deleted folder {self.name}")
return
return False
self.deleted = True
return True

View File

@@ -61,6 +61,11 @@ class Network(SimComponent):
software.run()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()
rm.add_request(

View File

@@ -12,8 +12,9 @@ from pydantic import BaseModel, Field
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -113,10 +114,15 @@ class NetworkInterface(SimComponent, ABC):
self.enable()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
return rm
@@ -140,14 +146,16 @@ class NetworkInterface(SimComponent, ABC):
return state
@abstractmethod
def enable(self):
def enable(self) -> bool:
"""Enable the interface."""
pass
return False
@abstractmethod
def disable(self):
def disable(self) -> bool:
"""Disable the interface."""
pass
return False
def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None:
"""
@@ -257,10 +265,9 @@ class NetworkInterface(SimComponent, ABC):
"""
Apply a timestep evolution to this component.
This just clears the nmne count back to 0.tests/integration_tests/network/test_capture_nmne.py
This just clears the nmne count back to 0.
"""
super().apply_timestep(timestep=timestep)
self.nmne.clear()
class WiredNetworkInterface(NetworkInterface, ABC):
@@ -284,24 +291,24 @@ class WiredNetworkInterface(NetworkInterface, ABC):
_connected_link: Optional[Link] = None
"The network link to which the network interface is connected."
def enable(self):
def enable(self) -> bool:
"""Attempt to enable the network interface."""
if self.enabled:
return
return True
if not self._connected_node:
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
return
return False
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.info(
f"Interface {self} cannot be enabled as the connected Node is not powered on"
)
return
return False
if not self._connected_link:
self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.")
return
return False
self.enabled = True
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
@@ -310,11 +317,12 @@ class WiredNetworkInterface(NetworkInterface, ABC):
)
if self._connected_link:
self._connected_link.endpoint_up()
return True
def disable(self):
def disable(self) -> bool:
"""Disable the network interface."""
if not self.enabled:
return
return True
self.enabled = False
if self._connected_node:
self._connected_node.sys_log.info(f"Network Interface {self} disabled")
@@ -322,6 +330,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
_LOGGER.debug(f"Interface {self} disabled")
if self._connected_link:
self._connected_link.endpoint_down()
return True
def connect_link(self, link: Link):
"""
@@ -496,7 +505,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
return state
def enable(self):
def enable(self) -> bool:
"""
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
@@ -512,8 +521,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
try:
pass
self._connected_node.default_gateway_hello()
return True
except AttributeError:
pass
return False
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
@@ -765,35 +776,74 @@ class Node(SimComponent):
self.sys_log.current_episode = episode
self.sys_log.setup_logger()
class _NodeIsOnValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the node is on.
This is useful because no actions should be being resolved if the node is off.
"""
node: Node
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the node is on or off."""
return self.node.operating_state == NodeOperatingState.ON
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
_node_is_on = Node._NodeIsOnValidator(node=self)
rm = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
self._service_request_manager = RequestManager()
rm.add_request("service", RequestType(func=self._service_request_manager))
rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on))
self._nic_request_manager = RequestManager()
rm.add_request("network_interface", RequestType(func=self._nic_request_manager))
rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on))
rm.add_request("file_system", RequestType(func=self.file_system._request_manager))
rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on))
# currently we don't have any applications nor processes, so these will be empty
self._process_request_manager = RequestManager()
rm.add_request("process", RequestType(func=self._process_request_manager))
rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on))
self._application_request_manager = RequestManager()
rm.add_request("application", RequestType(func=self._application_request_manager))
rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on))
rm.add_request("scan", RequestType(func=lambda request, context: self.reveal_to_red()))
rm.add_request(
"scan",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on
),
)
rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
rm.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
rm.add_request("reset", RequestType(func=lambda request, context: self.reset())) # TODO implement node reset
rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
rm.add_request(
"shutdown",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
),
)
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
rm.add_request(
"reset",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),
) # TODO implement node reset
rm.add_request(
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
) # TODO implement logon request
rm.add_request(
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
) # TODO implement logoff request
self._os_request_manager = RequestManager()
self._os_request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan()))
rm.add_request("os", RequestType(func=self._os_request_manager))
self._os_request_manager.add_request(
"scan",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on),
)
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
return rm
@@ -973,7 +1023,7 @@ class Node(SimComponent):
self.file_system.apply_timestep(timestep=timestep)
def scan(self) -> None:
def scan(self) -> bool:
"""
Scan the node and all the items within it.
@@ -987,8 +1037,9 @@ class Node(SimComponent):
to the red agent.
"""
self.node_scan_countdown = self.node_scan_duration
return True
def reveal_to_red(self) -> None:
def reveal_to_red(self) -> bool:
"""
Reveals the node and all the items within it to the red agent.
@@ -1002,34 +1053,40 @@ class Node(SimComponent):
`revealed_to_red` to `True`.
"""
self.red_scan_countdown = self.node_scan_duration
return True
def power_on(self):
def power_on(self) -> bool:
"""Power on the Node, enabling its NICs if it is in the OFF state."""
if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration
if self.start_up_duration <= 0:
self.operating_state = NodeOperatingState.ON
self._start_up_actions()
self.sys_log.info("Power on")
for network_interface in self.network_interfaces.values():
network_interface.enable()
return True
if self.operating_state == NodeOperatingState.OFF:
self.operating_state = NodeOperatingState.BOOTING
self.start_up_countdown = self.start_up_duration
return True
def power_off(self):
return False
def power_off(self) -> bool:
"""Power off the Node, disabling its NICs if it is in the ON state."""
if self.shut_down_duration <= 0:
self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Power off")
return True
if self.operating_state == NodeOperatingState.ON:
for network_interface in self.network_interfaces.values():
network_interface.disable()
self.operating_state = NodeOperatingState.SHUTTING_DOWN
self.shut_down_countdown = self.shut_down_duration
return True
return False
if self.shut_down_duration <= 0:
self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Power off")
def reset(self):
def reset(self) -> bool:
"""
Resets the node.
@@ -1040,6 +1097,8 @@ class Node(SimComponent):
self.is_resetting = True
self.sys_log.info("Resetting")
self.power_off()
return True
return False
def connect_nic(self, network_interface: NetworkInterface, port_name: Optional[str] = None):
"""

View File

@@ -51,13 +51,15 @@ class WirelessAccessPoint(IPWirelessNetworkInterface):
return state
def enable(self):
def enable(self) -> bool:
"""Enable the interface."""
pass
return True
def disable(self):
def disable(self) -> bool:
"""Disable the interface."""
pass
return True
def send_frame(self, frame: Frame) -> bool:
"""

View File

@@ -48,13 +48,15 @@ class WirelessNIC(IPWirelessNetworkInterface):
return state
def enable(self):
def enable(self) -> bool:
"""Enable the interface."""
pass
return True
def disable(self):
def disable(self) -> bool:
"""Disable the interface."""
pass
return True
def send_frame(self, frame: Frame) -> bool:
"""

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -293,6 +294,11 @@ class AccessControlList(SimComponent):
self._acl = [None] * (self.max_acl_rules - 1)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
# TODO: Add src and dst wildcard masks as positional args in this request.
rm = super()._init_request_manager()
@@ -308,19 +314,24 @@ class AccessControlList(SimComponent):
rm.add_request(
"add_rule",
RequestType(
func=lambda request, context: self.add_rule(
action=ACLAction[request[0]],
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
src_port=None if request[3] == "ALL" else Port[request[3]],
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
dst_port=None if request[5] == "ALL" else Port[request[5]],
position=int(request[6]),
func=lambda request, context: RequestResponse.from_bool(
self.add_rule(
action=ACLAction[request[0]],
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
src_port=None if request[3] == "ALL" else Port[request[3]],
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
dst_port=None if request[5] == "ALL" else Port[request[5]],
position=int(request[6]),
)
)
),
)
rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
rm.add_request(
"remove_rule",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.remove_rule(int(request[0])))),
)
return rm
def describe_state(self) -> Dict:
@@ -366,7 +377,7 @@ class AccessControlList(SimComponent):
src_port: Optional[Port] = None,
dst_port: Optional[Port] = None,
position: int = 0,
) -> None:
) -> bool:
"""
Adds a new ACL rule to control network traffic based on specified criteria.
@@ -423,10 +434,12 @@ class AccessControlList(SimComponent):
src_port=src_port,
dst_port=dst_port,
)
return True
else:
raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.")
return False
def remove_rule(self, position: int) -> None:
def remove_rule(self, position: int) -> bool:
"""
Remove an ACL rule from a specific position.
@@ -437,8 +450,10 @@ class AccessControlList(SimComponent):
rule = self._acl[position] # noqa
self._acl[position] = None
del rule
return True
else:
raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.")
return False
def is_permitted(self, frame: Frame) -> Tuple[bool, ACLRule]:
"""Check if a packet with the given properties is permitted through the ACL."""
@@ -1082,6 +1097,11 @@ class Router(NetworkNode):
super().setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))
return rm

View File

@@ -146,9 +146,12 @@ def arcd_uc2_network() -> Network:
)
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
db_client_1 = client_1.software_manager.install(DatabaseClient)
db_client_1 = client_1.software_manager.software.get("DatabaseClient")
client_1.software_manager.install(DatabaseClient)
db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14"))
db_client_1.run()
web_browser_1 = client_1.software_manager.software.get("WebBrowser")
web_browser_1.target_url = "http://arcd.com/users/"
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
db_manipulation_bot.configure(
@@ -170,9 +173,10 @@ def arcd_uc2_network() -> Network:
client_2.power_on()
client_2.software_manager.install(DatabaseClient)
db_client_2 = client_2.software_manager.software.get("DatabaseClient")
db_client_2.configure(server_ip_address=IPv4Address("192.168.1.14"))
db_client_2.run()
web_browser = client_2.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
web_browser_2 = client_2.software_manager.software.get("WebBrowser")
web_browser_2.target_url = "http://arcd.com/users/"
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
# Domain Controller

View File

@@ -1,5 +1,6 @@
from typing import Dict
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.controller import DomainController
from primaite.simulator.network.container import Network
@@ -26,12 +27,18 @@ class Simulation(SimComponent):
self.network.setup_for_episode(episode=episode)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# pass through network requests to the network objects
rm.add_request("network", RequestType(func=self.network._request_manager))
# pass through domain requests to the domain object
rm.add_request("domain", RequestType(func=self.domain._request_manager))
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
# if 'do_nothing' is requested, just return a success
rm.add_request("do_nothing", RequestType(func=lambda request, context: RequestResponse(status="success")))
return rm
def describe_state(self) -> Dict:

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
from uuid import uuid4
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
@@ -36,8 +37,13 @@ class DatabaseClient(Application):
super().__init__(**kwargs)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("execute", RequestType(func=lambda request, context: self.execute()))
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
return rm
def execute(self) -> bool:

View File

@@ -4,6 +4,7 @@ from typing import Dict, Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
@@ -74,9 +75,17 @@ class DataManipulationBot(Application):
return db_client
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
rm.add_request(
name="execute",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
)
return rm
@@ -179,23 +188,22 @@ class DataManipulationBot(Application):
"""
super().run()
def attack(self):
def attack(self) -> bool:
"""Perform the attack steps after opening the application."""
if not self._can_perform_action():
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
self.run()
self._application_loop()
return self._application_loop()
def _application_loop(self):
def _application_loop(self) -> bool:
"""
The main application loop of the bot, handling the attack process.
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if not self._can_perform_action():
return
self.num_executions += 1
self.num_executions += 1
return False
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self._logon()
@@ -207,8 +215,12 @@ class DataManipulationBot(Application):
DataManipulationAttackStage.FAILED,
):
self.attack_stage = DataManipulationAttackStage.NOT_STARTED
return True
else:
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
return False
def apply_timestep(self, timestep: int) -> None:
"""

View File

@@ -4,6 +4,7 @@ from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -57,9 +58,17 @@ class DoSBot(DatabaseClient):
self.max_sessions = 1000 # override normal max sessions
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
rm.add_request(
name="execute",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())),
)
return rm
@@ -97,26 +106,26 @@ class DoSBot(DatabaseClient):
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
)
def run(self):
def run(self) -> bool:
"""Run the Denial of Service Bot."""
super().run()
self._application_loop()
return self._application_loop()
def _application_loop(self):
def _application_loop(self) -> bool:
"""
The main application loop for the Denial of Service bot.
The loop goes through the stages of a DoS attack.
"""
if not self._can_perform_action():
return
return False
# DoS bot cannot do anything without a target
if not self.target_ip_address or not self.target_port:
self.sys_log.error(
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
)
return
return True
self.clear_connections()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
@@ -126,6 +135,7 @@ class DoSBot(DatabaseClient):
self.attack_stage = DoSAttackStage.NOT_STARTED
else:
self.attack_stage = DoSAttackStage.COMPLETED
return True
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
"""

View File

@@ -6,6 +6,7 @@ from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
@@ -50,9 +51,17 @@ class WebBrowser(Application):
self.run()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa
name="execute",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self.get_webpage())
), # noqa
)
return rm

View File

@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
@@ -79,15 +80,20 @@ class Service(IOSoftware):
return super().receive(payload=payload, session_id=session_id, **kwargs)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
rm.add_request("stop", RequestType(func=lambda request, context: self.stop()))
rm.add_request("start", RequestType(func=lambda request, context: self.start()))
rm.add_request("pause", RequestType(func=lambda request, context: self.pause()))
rm.add_request("resume", RequestType(func=lambda request, context: self.resume()))
rm.add_request("restart", RequestType(func=lambda request, context: self.restart()))
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop())))
rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start())))
rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause())))
rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume())))
rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart())))
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
return rm
@abstractmethod
@@ -106,17 +112,19 @@ class Service(IOSoftware):
state["health_state_visible"] = self.health_state_visible.value
return state
def stop(self) -> None:
def stop(self) -> bool:
"""Stop the service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Stopping service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
return True
return False
def start(self, **kwargs) -> None:
def start(self, **kwargs) -> bool:
"""Start the service."""
# cant start the service if the node it is on is off
if not super()._can_perform_action():
return
return False
if self.operating_state == ServiceOperatingState.STOPPED:
self.sys_log.info(f"Starting service {self.name}")
@@ -124,36 +132,47 @@ class Service(IOSoftware):
# set software health state to GOOD if initially set to UNUSED
if self.health_state_actual == SoftwareHealthState.UNUSED:
self.set_health_state(SoftwareHealthState.GOOD)
return True
return False
def pause(self) -> None:
def pause(self) -> bool:
"""Pause the service."""
if self.operating_state == ServiceOperatingState.RUNNING:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.PAUSED
return True
return False
def resume(self) -> None:
def resume(self) -> bool:
"""Resume paused service."""
if self.operating_state == ServiceOperatingState.PAUSED:
self.sys_log.info(f"Resuming service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
return True
return False
def restart(self) -> None:
def restart(self) -> bool:
"""Restart running service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.RESTARTING
self.restart_countdown = self.restart_duration
return True
return False
def disable(self) -> None:
def disable(self) -> bool:
"""Disable the service."""
self.sys_log.info(f"Disabling Application {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
return True
def enable(self) -> None:
def enable(self) -> bool:
"""Enable the disabled service."""
if self.operating_state == ServiceOperatingState.DISABLED:
self.sys_log.info(f"Enabling Application {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
return True
return False
def apply_timestep(self, timestep: int) -> None:
"""

View File

@@ -5,6 +5,7 @@ from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from primaite.interface.request import RequestResponse
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -101,20 +102,27 @@ class Software(SimComponent):
"Current number of ticks left to patch the software."
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
"compromise",
RequestType(
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
func=lambda request, context: RequestResponse.from_bool(
self.set_health_state(SoftwareHealthState.COMPROMISED)
),
),
)
rm.add_request(
"patch",
RequestType(
func=lambda request, context: self.patch(),
func=lambda request, context: RequestResponse.from_bool(self.patch()),
),
)
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
return rm
def _get_session_details(self, session_id: str) -> Session:
@@ -148,7 +156,7 @@ class Software(SimComponent):
)
return state
def set_health_state(self, health_state: SoftwareHealthState) -> None:
def set_health_state(self, health_state: SoftwareHealthState) -> bool:
"""
Assign a new health state to this software.
@@ -160,6 +168,7 @@ class Software(SimComponent):
:type health_state: SoftwareHealthState
"""
self.health_state_actual = health_state
return True
def install(self) -> None:
"""
@@ -180,15 +189,18 @@ class Software(SimComponent):
"""
pass
def scan(self) -> None:
def scan(self) -> bool:
"""Update the observed health status to match the actual health status."""
self.health_state_visible = self.health_state_actual
return True
def patch(self) -> None:
def patch(self) -> bool:
"""Perform a patch on the software."""
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
self._patching_countdown = self.patching_duration
self.set_health_state(SoftwareHealthState.PATCHING)
return True
return False
def _update_patch_status(self) -> None:
"""Update the patch status of the software."""

View File

@@ -12,6 +12,8 @@ def test_passing_actions_down(monkeypatch) -> None:
sim = Simulation()
pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0")
pc1.start_up_duration = 0
pc1.power_on()
pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0")
srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0")
s1 = Switch(hostname="switch1")

View File

@@ -125,8 +125,8 @@ def test_describe_state_nmne(uc2_network):
web_server_nic_state = web_server_nic.describe_state()
db_server_nic_state = db_server_nic.describe_state()
uc2_network.apply_timestep(timestep=0)
assert web_server_nic_state["nmne"] == {}
assert db_server_nic_state["nmne"] == {}
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
# Perform another "DELETE" query
db_client.query("DELETE")
@@ -135,8 +135,8 @@ def test_describe_state_nmne(uc2_network):
web_server_nic_state = web_server_nic.describe_state()
db_server_nic_state = db_server_nic.describe_state()
uc2_network.apply_timestep(timestep=0)
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}}
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}}
def test_capture_nmne_observations(uc2_network):

View File

@@ -0,0 +1,160 @@
# some test cases:
# 0. test that sending a request to a valid target results in a success
# 1. test that sending a request to a component that doesn't exist results in a failure
# 2. test that sending a request to a software on a turned-off component results in a failure
# 3. test every implemented action under several different situation, some of which should lead to a success and some to a failure.
import pytest
from primaite.interface.request import RequestResponse
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests.conftest import TestApplication, TestService
def test_successful_application_requests(example_network):
net = example_network
client_1 = net.get_node_by_hostname("client_1")
client_1.software_manager.install(TestApplication)
client_1.software_manager.software.get("TestApplication").run()
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
assert resp_1 == RequestResponse(status="success", data={})
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "patch"])
assert resp_2 == RequestResponse(status="success", data={})
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
assert resp_3 == RequestResponse(status="success", data={})
def test_successful_service_requests(example_network):
net = example_network
server_1 = net.get_node_by_hostname("server_1")
server_1.software_manager.install(TestService)
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
for verb in [
"disable",
"enable",
"start",
"stop",
"start",
"restart",
"pause",
"resume",
"compromise",
"scan",
"patch",
]:
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
assert resp_1 == RequestResponse(status="success", data={})
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)
# lazily apply timestep 7 times to make absolutely sure any time-based things like restart have a chance to finish
def test_non_existent_requests(example_network):
net = example_network
resp_1 = net.apply_request(["fake"])
assert resp_1.status == "unreachable"
resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"])
assert resp_2.status == "unreachable"
@pytest.mark.parametrize(
"node_request",
[
["node", "client_1", "file_system", "folder", "root", "scan"],
["node", "client_1", "os", "scan"],
["node", "client_1", "service", "DNSClient", "stop"],
["node", "client_1", "application", "WebBrowser", "scan"],
["node", "client_1", "network_interface", 1, "disable"],
],
)
def test_request_fails_if_node_off(example_network, node_request):
"""Test that requests succeed when the node is on, and fail if the node is off."""
net = example_network
client_1: HostNode = net.get_node_by_hostname("client_1")
client_1.shut_down_duration = 0
assert client_1.operating_state == NodeOperatingState.ON
resp_1 = net.apply_request(node_request)
assert resp_1.status == "success"
client_1.power_off()
assert client_1.operating_state == NodeOperatingState.OFF
resp_2 = net.apply_request(node_request)
assert resp_2.status == "failure"
class TestDataManipulationGreenRequests:
def test_node_off(self, uc2_network):
"""Test that green requests succeed when the node is on and fail if the node is off."""
net = uc2_network
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"])
assert client_1_browser_execute.status == "success"
assert client_1_db_client_execute.status == "success"
assert client_2_browser_execute.status == "success"
assert client_2_db_client_execute.status == "success"
client_1 = net.get_node_by_hostname("client_1")
client_2 = net.get_node_by_hostname("client_2")
client_1.shut_down_duration = 0
client_1.power_off()
client_2.shut_down_duration = 0
client_2.power_off()
client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
client_1_db_client_execute_off = net.apply_request(
["node", "client_1", "application", "DatabaseClient", "execute"]
)
client_2_browser_execute_off = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
client_2_db_client_execute_off = net.apply_request(
["node", "client_2", "application", "DatabaseClient", "execute"]
)
assert client_1_browser_execute_off.status == "failure"
assert client_1_db_client_execute_off.status == "failure"
assert client_2_browser_execute_off.status == "failure"
assert client_2_db_client_execute_off.status == "failure"
def test_acl_block(self, uc2_network):
"""Test that green requests succeed when not blocked by ACLs but fail when blocked."""
net = uc2_network
router: Router = net.get_node_by_hostname("router_1")
client_1: HostNode = net.get_node_by_hostname("client_1")
client_2: HostNode = net.get_node_by_hostname("client_2")
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
assert client_1_browser_execute.status == "success"
assert client_2_browser_execute.status == "success"
router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
assert client_1_browser_execute.status == "failure"
assert client_2_browser_execute.status == "failure"
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"])
assert client_1_db_client_execute.status == "success"
assert client_2_db_client_execute.status == "success"
router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"])
assert client_1_db_client_execute.status == "failure"
assert client_2_db_client_execute.status == "failure"

View File

@@ -0,0 +1,32 @@
import pytest
from pydantic import ValidationError
from primaite.interface.request import RequestResponse
def test_creating_response_object():
"""Test that we can create a response object with given parameters."""
r1 = RequestResponse(status="success", data={"test_data": 1, "other_data": 2})
r2 = RequestResponse(status="unreachable")
r3 = RequestResponse(data={"test_data": "is_good"})
r4 = RequestResponse()
assert isinstance(r1, RequestResponse)
assert isinstance(r2, RequestResponse)
assert isinstance(r3, RequestResponse)
assert isinstance(r4, RequestResponse)
def test_creating_response_from_boolean():
"""Test that we can build a response with a single boolean."""
r1 = RequestResponse.from_bool(True)
assert r1.status == "success"
r2 = RequestResponse.from_bool(False)
assert r2.status == "failure"
with pytest.raises(ValidationError):
r3 = RequestResponse.from_bool(1)
with pytest.raises(ValidationError):
r4 = RequestResponse.from_bool("good")
with pytest.raises(ValidationError):
r5 = RequestResponse.from_bool({"data": True})