Merge remote-tracking branch 'origin/dev' into feature/2350-confirm-action-observation-space-conforms-to-CAOS-0.7
This commit is contained in:
@@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Made requests fail to reach their target if the node is off
|
||||
- Added responses to requests
|
||||
- Made environment reset completely recreate the game object.
|
||||
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
|
||||
- Changed the data manipulation scenario to include a second green agent on client 1.
|
||||
|
||||
@@ -6,7 +6,7 @@ SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
AUTOSUMMARY="source\_autosummary"
|
||||
AUTOSUMMARY="source/_autosummary"
|
||||
|
||||
# Remove command is different depending on OS
|
||||
ifdef OS
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
Request System
|
||||
==============
|
||||
**************
|
||||
|
||||
``SimComponent`` objects in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``RequestType``.
|
||||
|
||||
@@ -12,26 +12,37 @@ Just like other aspects of SimComponent, the request types are not managed centr
|
||||
- API
|
||||
When requesting an action within the simulation, these two arguments must be provided:
|
||||
|
||||
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '<node-name>', 'service', '<service-name>', 'restart']`.
|
||||
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as ``['network', 'node', '<node-name>', 'service', '<service-name>', 'restart']``.
|
||||
2. ``context`` - optional extra information that can be used to decide how to process the request. This is formatted as a dictionary. For example, if the request requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient.
|
||||
|
||||
When a request is resolved, it returns a success status, and optional additional data about the request.
|
||||
|
||||
``status`` can be one of:
|
||||
|
||||
* ``success``: the request was executed
|
||||
* ``failure``: the request could not be executed
|
||||
* ``unreachable``: the target for the request was not found
|
||||
* ``pending``: the request was initiated, but has not finished during this step
|
||||
|
||||
``data`` can be a dictionary with any arbitrary JSON-like data to describe the outcome of the request.
|
||||
|
||||
- ``request`` detail
|
||||
The request is a list of strings which help specify who should handle the request. The strings in the request list help RequestManagers traverse the 'ownership tree' of SimComponent. The example given above would be handled in the following way:
|
||||
|
||||
1. ``Simulation`` receives `['network', 'node', '<node-name>', 'service', '<service-name>', 'restart']`.
|
||||
1. ``Simulation`` receives ``['network', 'node', 'computer_1', 'service', 'DNSService', 'restart']``.
|
||||
The first element of the request is ``network``, therefore it passes the request down to its network.
|
||||
2. ``Network`` receives `['node', '<node-name>', 'service', '<service-name>', 'restart']`.
|
||||
2. ``Network`` receives ``['node', 'computer_1', 'service', 'DNSService', 'restart']``.
|
||||
The first element of the request is ``node``, therefore the network looks at the node name and passes the request down to the node with that name.
|
||||
3. ``Node`` receives `['service', '<service-name>', 'restart']`.
|
||||
3. ``computer_1`` receives ``['service', 'DNSService', 'restart']``.
|
||||
The first element of the request is ``service``, therefore the node looks at the service name and passes the rest of the request to the service with that name.
|
||||
4. ``Service`` receives ``['restart']``.
|
||||
4. ``DNSService`` receives ``['restart']``.
|
||||
Since ``restart`` is a defined request type in the service's own RequestManager, the service performs a restart.
|
||||
|
||||
- ``context`` detail
|
||||
The context is not used by any of the currently implemented components or requests.
|
||||
|
||||
Technical Detail
|
||||
----------------
|
||||
================
|
||||
|
||||
This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.RequestType`, and :py:class:`primaite.simulator.core.RequestManager`.
|
||||
|
||||
@@ -93,3 +104,19 @@ An example of how this works is in the :py:class:`primaite.simulator.network.har
|
||||
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
|
||||
|
||||
This process is repeated until the request word corresponds to a callable function rather than another ``RequestManager`` .
|
||||
|
||||
Request Validation
|
||||
------------------
|
||||
|
||||
There are times when a request should be rejected. For instance, if an agent attempts to run an application on a node that is currently off. For this purpose, requests are filtered by an object called a validator. :py:class:`primaite.simulator.core.RequestPermissionValidator` is a basic class whose ``__call__()`` method returns ``True`` if the request should be permitted or ``False`` if it cannot be permitted. For example, the Node class has a validator called :py:class:`primaite.simulator.network.hardware.base.Node._NodeIsOnValidator<_NodeIsOnValidator>` which allows requests only when the operating status of the node is ``ON``.
|
||||
|
||||
Requests that are specified without a validator automatically get assigned an ``AllowAllValidator`` which allows requests no matter what.
|
||||
|
||||
Request Response
|
||||
----------------
|
||||
|
||||
The :py:class:`primaite.interface.request.RequestResponse<RequestResponse>` is a data transfer object that carries response data between the simulator and the game layer. The ``status`` field reports on the success or failure, and the ``data`` field is for any additional data. The most common way that this class is initiated is by its ``from_bool`` method. This way, given a True or False, a successful or failed request response is generated, respectively (with an empty data field).
|
||||
|
||||
For instance, the ``execute`` action on a :py:class:`primaite.simulator.system.applications.web_browser.WebBrowser<WebBrowser>` calls the ``get_webpage()`` method of the ``WebBrowser``. ``get_webpage()`` returns a True if the webpage was successfully retrieved, and False if unsuccessful for any reason, such as being blocked by an ACL, or if the database server is unresponsive. The boolean returned from ``get_webpage()`` is used to create the request response.
|
||||
|
||||
Just as the requests themselves were passed from owner to component, the request response is bubbled back up from component to owner until it arrives at the game layer.
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
Simulation State
|
||||
================
|
||||
|
||||
``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the childrens' own ``describe_state`` methods.
|
||||
``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the children's own ``describe_state`` methods.
|
||||
|
||||
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objetcs must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
|
||||
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objects must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
|
||||
|
||||
This code snippet demonstrates how the state information is defined within the ``SimComponent`` class:
|
||||
|
||||
|
||||
@@ -492,9 +492,9 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
"add_rule",
|
||||
permission_str,
|
||||
protocol,
|
||||
src_ip,
|
||||
str(src_ip),
|
||||
src_port,
|
||||
dst_ip,
|
||||
str(dst_ip),
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
|
||||
@@ -170,9 +170,13 @@ class PrimaiteGame:
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
response = self.simulation.apply_request(request)
|
||||
agent_actions[agent.agent_name] = {
|
||||
"action": action_choice,
|
||||
"parameters": options,
|
||||
"response": response.model_dump(),
|
||||
}
|
||||
return agent_actions
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
|
||||
0
src/primaite/interface/__init__.py
Normal file
0
src/primaite/interface/__init__.py
Normal file
44
src/primaite/interface/request.py
Normal file
44
src/primaite/interface/request.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
|
||||
class RequestResponse(BaseModel):
|
||||
"""Schema for generic request responses."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
"""Cannot have extra fields in the response. Anything custom goes into the data field."""
|
||||
|
||||
status: Literal["pending", "success", "failure", "unreachable"] = "pending"
|
||||
"""
|
||||
What is the current status of the request:
|
||||
- pending - the request has not been received yet, or it has been received but it's still being processed.
|
||||
- success - the request has been received and executed successfully.
|
||||
- failure - the request has been received and attempted, but execution failed.
|
||||
- unreachable - the request could not reach it's intended target, either because it doesn't exist or the target
|
||||
is off.
|
||||
"""
|
||||
|
||||
data: Dict = {}
|
||||
"""Catch-all place to provide any additional data that was generated as a response to the request."""
|
||||
# TODO: currently, status and data have default values, because I don't want to interrupt existing functionality too
|
||||
# much. However, in the future we might consider making them mandatory.
|
||||
|
||||
@classmethod
|
||||
@validate_call
|
||||
def from_bool(cls, status_bool: StrictBool) -> RequestResponse:
|
||||
"""
|
||||
Construct a basic request response from a boolean.
|
||||
|
||||
True maps to a success status. False maps to a failure status.
|
||||
|
||||
:param status_bool: Whether to create a successful response
|
||||
:type status_bool: bool
|
||||
"""
|
||||
if status_bool is True:
|
||||
return cls(status="success", data={})
|
||||
elif status_bool is False:
|
||||
return cls(status="failure", data={})
|
||||
@@ -426,13 +426,13 @@
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info[0]\n",
|
||||
" red_action = red_info['action']\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
|
||||
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
" return red_str\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -492,7 +492,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent."
|
||||
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) increased from 0 to 1, but only right after the red attack, so we probably cannot see it now."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -510,9 +510,9 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
|
||||
"print(f\"Blue reward:{reward}\" )"
|
||||
]
|
||||
},
|
||||
@@ -535,7 +535,7 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
|
||||
"print(f\"Blue reward:{reward:.2f}\" )"
|
||||
@@ -557,17 +557,17 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(50) # Block client 1\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(51) # Block client 2\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"for step in range(30):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -606,20 +606,35 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's unblock client 2\n",
|
||||
" env.step(58) # remove ACL rule 6\n",
|
||||
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" env.step(57) # remove ACL rule 5\n",
|
||||
"else:\n",
|
||||
" print(\"something went wrong, neither client has NMNEs\")"
|
||||
"env.step(58) # Remove the ACL rule that blocks client 1\n",
|
||||
"env.step(57) # Remove the ACL rule that blocks client 2\n",
|
||||
"\n",
|
||||
"tries = 0\n",
|
||||
"while True:\n",
|
||||
" tries += 1\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
"\n",
|
||||
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
|
||||
" break\n",
|
||||
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 2 has NMNEs, so let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
|
||||
" break\n",
|
||||
" if tries>100:\n",
|
||||
" print(\"Error: NMNE never increased\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"env.step(13) # Patch the database\n",
|
||||
"..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks."
|
||||
"Now, the reward will eventually increase to 0.9, even after red agent attempts subsequent attacks."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -628,9 +643,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for step in range(30):\n",
|
||||
"\n",
|
||||
"for step in range(40):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -648,6 +664,13 @@
|
||||
"source": [
|
||||
"env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -93,7 +93,7 @@ class PrimaiteIO:
|
||||
{
|
||||
"episode": episode,
|
||||
"timestep": timestep,
|
||||
"agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()},
|
||||
"agent_actions": agent_actions,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
# flake8: noqa
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
|
||||
class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
@@ -21,15 +24,15 @@ class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, request: List[str], context: Dict) -> bool:
|
||||
"""Use the request and context paramters to decide whether the request should be permitted."""
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Use the request and context parameters to decide whether the request should be permitted."""
|
||||
pass
|
||||
|
||||
|
||||
class AllowAllValidator(RequestPermissionValidator):
|
||||
"""Always allows the request."""
|
||||
|
||||
def __call__(self, request: List[str], context: Dict) -> bool:
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Always allow the request."""
|
||||
return True
|
||||
|
||||
@@ -42,7 +45,7 @@ class RequestType(BaseModel):
|
||||
the request can be performed or not.
|
||||
"""
|
||||
|
||||
func: Callable[[List[str], Dict], None]
|
||||
func: Callable[[RequestFormat, Dict], RequestResponse]
|
||||
"""
|
||||
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
|
||||
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
|
||||
@@ -71,7 +74,7 @@ class RequestManager(BaseModel):
|
||||
request_types: Dict[str, RequestType] = {}
|
||||
"""maps request name to an RequestType object."""
|
||||
|
||||
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Process an request request.
|
||||
|
||||
@@ -84,23 +87,23 @@ class RequestManager(BaseModel):
|
||||
:raises RuntimeError: If the request parameter does not have a valid request name as the first item.
|
||||
"""
|
||||
request_key = request[0]
|
||||
request_options = request[1:]
|
||||
|
||||
if request_key not in self.request_types:
|
||||
msg = (
|
||||
f"Request {request} could not be processed because {request_key} is not a valid request name",
|
||||
"within this RequestManager",
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
_LOGGER.debug(msg)
|
||||
return RequestResponse(status="unreachable", data={"reason": msg})
|
||||
|
||||
request_type = self.request_types[request_key]
|
||||
request_options = request[1:]
|
||||
|
||||
if not request_type.validator(request_options, context):
|
||||
_LOGGER.debug(f"Request {request} was denied due to insufficient permissions")
|
||||
return
|
||||
return RequestResponse(status="failure", data={"reason": "request validation failed"})
|
||||
|
||||
request_type.func(request_options, context)
|
||||
return request_type.func(request_options, context)
|
||||
|
||||
def add_request(self, name: str, request_type: RequestType) -> None:
|
||||
"""
|
||||
@@ -202,7 +205,8 @@ class SimComponent(BaseModel):
|
||||
}
|
||||
return state
|
||||
|
||||
def apply_request(self, request: List[str], context: Dict = {}) -> None:
|
||||
@validate_call
|
||||
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
|
||||
"""
|
||||
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
|
||||
|
||||
@@ -222,7 +226,7 @@ class SimComponent(BaseModel):
|
||||
"""
|
||||
if self._request_manager is None:
|
||||
return
|
||||
self._request_manager(request, context)
|
||||
return self._request_manager(request, context)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -80,6 +80,11 @@ class DomainController(SimComponent):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
# Action 'account' matches requests like:
|
||||
# ['account', '<account-uuid>', *account_action]
|
||||
@@ -87,6 +92,7 @@ class DomainController(SimComponent):
|
||||
"account",
|
||||
RequestType(
|
||||
func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context),
|
||||
# TODO: not sure what should get returned here, revisit
|
||||
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -114,16 +114,17 @@ class File(FileSystemItemABC):
|
||||
state["num_access"] = self.num_access
|
||||
return state
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Updates the visible statuses of the file."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
|
||||
self.visible_health_status = self.health_status
|
||||
return True
|
||||
|
||||
def reveal_to_red(self) -> None:
|
||||
"""Reveals the folder/file to the red agent."""
|
||||
@@ -132,7 +133,7 @@ class File(FileSystemItemABC):
|
||||
return
|
||||
self.revealed_to_red = True
|
||||
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Check if the file has been changed.
|
||||
|
||||
@@ -142,7 +143,7 @@ class File(FileSystemItemABC):
|
||||
"""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to check hash of deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
current_hash = None
|
||||
|
||||
# if file is real, read the file contents
|
||||
@@ -164,12 +165,13 @@ class File(FileSystemItemABC):
|
||||
# if the previous hash and current hash do not match, mark file as corrupted
|
||||
if self.previous_hash is not current_hash:
|
||||
self.corrupt()
|
||||
return True
|
||||
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to repair deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# set file status to good if corrupt
|
||||
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
@@ -178,12 +180,13 @@ class File(FileSystemItemABC):
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to corrupt deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# set file status to good if corrupt
|
||||
if self.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
@@ -192,12 +195,13 @@ class File(FileSystemItemABC):
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""Determines if the file needs to be repaired or unmarked as deleted."""
|
||||
if self.deleted:
|
||||
self.deleted = False
|
||||
return
|
||||
return True
|
||||
|
||||
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.health_status = FileSystemItemHealthStatus.GOOD
|
||||
@@ -205,13 +209,15 @@ class File(FileSystemItemABC):
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> bool:
|
||||
"""Marks the file as deleted."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
self.deleted = True
|
||||
self.sys_log.info(f"File deleted {self.folder_name}/{self.name}")
|
||||
return True
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
@@ -39,18 +40,27 @@ class FileSystem(SimComponent):
|
||||
self.create_folder("root")
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
self._delete_manager = RequestManager()
|
||||
self._delete_manager.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: self.delete_file(folder_name=request[0], file_name=request[1])
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.delete_file(folder_name=request[0], file_name=request[1])
|
||||
)
|
||||
),
|
||||
)
|
||||
self._delete_manager.add_request(
|
||||
name="folder",
|
||||
request_type=RequestType(func=lambda request, context: self.delete_folder(folder_name=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
name="delete",
|
||||
@@ -61,12 +71,16 @@ class FileSystem(SimComponent):
|
||||
self._restore_manager.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: self.restore_file(folder_name=request[0], file_name=request[1])
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.restore_file(folder_name=request[0], file_name=request[1])
|
||||
)
|
||||
),
|
||||
)
|
||||
self._restore_manager.add_request(
|
||||
name="folder",
|
||||
request_type=RequestType(func=lambda request, context: self.restore_folder(folder_name=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.restore_folder(folder_name=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
name="restore",
|
||||
@@ -142,7 +156,7 @@ class FileSystem(SimComponent):
|
||||
)
|
||||
return folder
|
||||
|
||||
def delete_folder(self, folder_name: str):
|
||||
def delete_folder(self, folder_name: str) -> bool:
|
||||
"""
|
||||
Deletes a folder, removes it from the folders list and removes any child folders and files.
|
||||
|
||||
@@ -150,24 +164,26 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
return False
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
else:
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
|
||||
return False
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str):
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
return True
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str) -> None:
|
||||
"""
|
||||
Deletes a folder via its uuid.
|
||||
|
||||
@@ -303,7 +319,7 @@ class FileSystem(SimComponent):
|
||||
|
||||
return file
|
||||
|
||||
def delete_file(self, folder_name: str, file_name: str):
|
||||
def delete_file(self, folder_name: str, file_name: str) -> bool:
|
||||
"""
|
||||
Delete a file by its name from a specific folder.
|
||||
|
||||
@@ -317,8 +333,10 @@ class FileSystem(SimComponent):
|
||||
# increment file creation
|
||||
self.num_file_deletions += 1
|
||||
folder.remove_file(file)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_file_by_id(self, folder_uuid: str, file_uuid: str):
|
||||
def delete_file_by_id(self, folder_uuid: str, file_uuid: str) -> None:
|
||||
"""
|
||||
Deletes a file via its uuid.
|
||||
|
||||
@@ -335,7 +353,7 @@ class FileSystem(SimComponent):
|
||||
else:
|
||||
self.sys_log.error(f"Unable to delete file that does not exist. (id: {file_uuid})")
|
||||
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str) -> None:
|
||||
"""
|
||||
Move a file from one folder to another.
|
||||
|
||||
@@ -423,7 +441,7 @@ class FileSystem(SimComponent):
|
||||
# Agent actions
|
||||
###############################################################
|
||||
|
||||
def scan(self, instant_scan: bool = False):
|
||||
def scan(self, instant_scan: bool = False) -> None:
|
||||
"""
|
||||
Scan all the folders (and child files) in the file system.
|
||||
|
||||
@@ -432,7 +450,7 @@ class FileSystem(SimComponent):
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].scan(instant_scan=instant_scan)
|
||||
|
||||
def reveal_to_red(self, instant_scan: bool = False):
|
||||
def reveal_to_red(self, instant_scan: bool = False) -> None:
|
||||
"""
|
||||
Reveals all the folders (and child files) in the file system to the red agent.
|
||||
|
||||
@@ -441,7 +459,7 @@ class FileSystem(SimComponent):
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].reveal_to_red(instant_scan=instant_scan)
|
||||
|
||||
def restore_folder(self, folder_name: str):
|
||||
def restore_folder(self, folder_name: str) -> bool:
|
||||
"""
|
||||
Restore a folder.
|
||||
|
||||
@@ -454,13 +472,14 @@ class FileSystem(SimComponent):
|
||||
|
||||
if folder is None:
|
||||
self.sys_log.error(f"Unable to restore folder {folder_name}. Folder is not in deleted folder list.")
|
||||
return
|
||||
return False
|
||||
|
||||
self.deleted_folders.pop(folder.uuid, None)
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
return True
|
||||
|
||||
def restore_file(self, folder_name: str, file_name: str):
|
||||
def restore_file(self, folder_name: str, file_name: str) -> bool:
|
||||
"""
|
||||
Restore a file.
|
||||
|
||||
@@ -473,12 +492,15 @@ class FileSystem(SimComponent):
|
||||
:type: file_name: str
|
||||
"""
|
||||
folder = self.get_folder(folder_name=folder_name)
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.")
|
||||
return False
|
||||
|
||||
if folder:
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
|
||||
if file is None:
|
||||
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
|
||||
return
|
||||
if not file:
|
||||
msg = f"Unable to restore file {file_name}. File was not found."
|
||||
self.sys_log.error(msg)
|
||||
return False
|
||||
|
||||
folder.restore_file(file_name=file_name)
|
||||
return folder.restore_file(file_name=file_name)
|
||||
|
||||
@@ -6,6 +6,7 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
@@ -100,14 +101,33 @@ class FileSystemItemABC(SimComponent):
|
||||
return state
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash()))
|
||||
rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair()))
|
||||
rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore()))
|
||||
rm.add_request(
|
||||
name="scan", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
|
||||
)
|
||||
rm.add_request(
|
||||
name="checkhash",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.check_hash())),
|
||||
)
|
||||
rm.add_request(
|
||||
name="repair",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.repair())),
|
||||
)
|
||||
rm.add_request(
|
||||
name="restore",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.restore())),
|
||||
)
|
||||
|
||||
rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt()))
|
||||
rm.add_request(
|
||||
name="corrupt",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.corrupt())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -124,9 +144,9 @@ class FileSystemItemABC(SimComponent):
|
||||
return convert_size(self.size)
|
||||
|
||||
@abstractmethod
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Scan the folder/file - updates the visible_health_status."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def reveal_to_red(self) -> None:
|
||||
@@ -134,7 +154,7 @@ class FileSystemItemABC(SimComponent):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Checks the has of the file to detect any changes.
|
||||
|
||||
@@ -142,30 +162,30 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""
|
||||
Repair the FileSystemItem.
|
||||
|
||||
True if successfully repaired. False otherwise.
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""
|
||||
Corrupt the FileSystemItem.
|
||||
|
||||
True if successfully corrupted. False otherwise.
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""Restore the file/folder to the state before it got ruined."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def delete(self) -> None:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
@@ -50,10 +51,17 @@ class Folder(FileSystemItemABC):
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="delete",
|
||||
request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.remove_file_by_name(file_name=request[0]))
|
||||
),
|
||||
)
|
||||
self._file_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
@@ -250,6 +258,21 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file_by_id(file_uuid=file_uuid)
|
||||
self.remove_file(file=file)
|
||||
|
||||
def remove_file_by_name(self, file_name: str) -> bool:
|
||||
"""
|
||||
Remove a file using its name.
|
||||
|
||||
:param file_name: filename
|
||||
:type file_name: str
|
||||
:return: Whether it was successfully removed.
|
||||
:rtype: bool
|
||||
"""
|
||||
for f in self.files.values():
|
||||
if f.name == file_name:
|
||||
self.remove_file(f)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_all_files(self):
|
||||
"""Removes all the files in the folder."""
|
||||
for file_id in self.files:
|
||||
@@ -259,7 +282,7 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
self.files = {}
|
||||
|
||||
def restore_file(self, file_name: str):
|
||||
def restore_file(self, file_name: str) -> bool:
|
||||
"""
|
||||
Restores a file.
|
||||
|
||||
@@ -269,13 +292,14 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file(file_name=file_name, include_deleted=True)
|
||||
if not file:
|
||||
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
|
||||
return
|
||||
return False
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file.uuid)
|
||||
return True
|
||||
|
||||
def quarantine(self):
|
||||
"""Quarantines the File System Folder."""
|
||||
@@ -289,7 +313,7 @@ class Folder(FileSystemItemABC):
|
||||
"""Returns true if the folder is being quarantined."""
|
||||
pass
|
||||
|
||||
def scan(self, instant_scan: bool = False) -> None:
|
||||
def scan(self, instant_scan: bool = False) -> bool:
|
||||
"""
|
||||
Update Folder visible status.
|
||||
|
||||
@@ -297,7 +321,7 @@ class Folder(FileSystemItemABC):
|
||||
"""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to scan deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
if instant_scan:
|
||||
for file_id in self.files:
|
||||
@@ -305,7 +329,7 @@ class Folder(FileSystemItemABC):
|
||||
file.scan()
|
||||
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
return
|
||||
return True
|
||||
|
||||
if self.scan_countdown <= 0:
|
||||
# scan one file per timestep
|
||||
@@ -314,6 +338,7 @@ class Folder(FileSystemItemABC):
|
||||
else:
|
||||
# scan already in progress
|
||||
self.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def reveal_to_red(self, instant_scan: bool = False):
|
||||
"""
|
||||
@@ -340,7 +365,7 @@ class Folder(FileSystemItemABC):
|
||||
# scan already in progress
|
||||
self.sys_log.info(f"Red Agent Scan is already in progress {self.name} (id: {self.uuid})")
|
||||
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Runs a :func:`check_hash` on all files in the folder.
|
||||
|
||||
@@ -353,7 +378,7 @@ class Folder(FileSystemItemABC):
|
||||
"""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to check hash of deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# iterate through the files and run a check hash
|
||||
no_corrupted_files = True
|
||||
@@ -369,12 +394,13 @@ class Folder(FileSystemItemABC):
|
||||
self.corrupt()
|
||||
|
||||
self.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to repair deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# iterate through the files in the folder
|
||||
for file_id in self.files:
|
||||
@@ -388,8 +414,9 @@ class Folder(FileSystemItemABC):
|
||||
self.health_status = FileSystemItemHealthStatus.GOOD
|
||||
|
||||
self.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""
|
||||
If a Folder is corrupted, run a repair on the folder and its child files.
|
||||
|
||||
@@ -405,12 +432,13 @@ class Folder(FileSystemItemABC):
|
||||
else:
|
||||
# scan already in progress
|
||||
self.sys_log.info(f"Folder restoration already in progress {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to corrupt deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# iterate through the files in the folder
|
||||
for file_id in self.files:
|
||||
@@ -421,11 +449,13 @@ class Folder(FileSystemItemABC):
|
||||
self.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
self.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> bool:
|
||||
"""Marks the file as deleted. Prevents agent actions from occuring."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to delete an already deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
self.deleted = True
|
||||
return True
|
||||
|
||||
@@ -61,6 +61,11 @@ class Network(SimComponent):
|
||||
software.run()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
self._node_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
|
||||
@@ -12,8 +12,9 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -113,10 +114,15 @@ class NetworkInterface(SimComponent, ABC):
|
||||
self.enable()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -140,14 +146,16 @@ class NetworkInterface(SimComponent, ABC):
|
||||
return state
|
||||
|
||||
@abstractmethod
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return False
|
||||
|
||||
def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None:
|
||||
"""
|
||||
@@ -257,10 +265,9 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
|
||||
This just clears the nmne count back to 0.tests/integration_tests/network/test_capture_nmne.py
|
||||
This just clears the nmne count back to 0.
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
self.nmne.clear()
|
||||
|
||||
|
||||
class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
@@ -284,24 +291,24 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
_connected_link: Optional[Link] = None
|
||||
"The network link to which the network interface is connected."
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Attempt to enable the network interface."""
|
||||
if self.enabled:
|
||||
return
|
||||
return True
|
||||
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
|
||||
return
|
||||
return False
|
||||
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.info(
|
||||
f"Interface {self} cannot be enabled as the connected Node is not powered on"
|
||||
)
|
||||
return
|
||||
return False
|
||||
|
||||
if not self._connected_link:
|
||||
self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.")
|
||||
return
|
||||
return False
|
||||
|
||||
self.enabled = True
|
||||
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
|
||||
@@ -310,11 +317,12 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
)
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_up()
|
||||
return True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the network interface."""
|
||||
if not self.enabled:
|
||||
return
|
||||
return True
|
||||
self.enabled = False
|
||||
if self._connected_node:
|
||||
self._connected_node.sys_log.info(f"Network Interface {self} disabled")
|
||||
@@ -322,6 +330,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
_LOGGER.debug(f"Interface {self} disabled")
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_down()
|
||||
return True
|
||||
|
||||
def connect_link(self, link: Link):
|
||||
"""
|
||||
@@ -496,7 +505,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
|
||||
return state
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""
|
||||
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
|
||||
|
||||
@@ -512,8 +521,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
try:
|
||||
pass
|
||||
self._connected_node.default_gateway_hello()
|
||||
return True
|
||||
except AttributeError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
@@ -765,35 +776,74 @@ class Node(SimComponent):
|
||||
self.sys_log.current_episode = episode
|
||||
self.sys_log.setup_logger()
|
||||
|
||||
class _NodeIsOnValidator(RequestPermissionValidator):
|
||||
"""
|
||||
When requests come in, this validator will only let them through if the node is on.
|
||||
|
||||
This is useful because no actions should be being resolved if the node is off.
|
||||
"""
|
||||
|
||||
node: Node
|
||||
"""Save a reference to the node instance."""
|
||||
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Return whether the node is on or off."""
|
||||
return self.node.operating_state == NodeOperatingState.ON
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
|
||||
# need a better name and better documentation.
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
_node_is_on = Node._NodeIsOnValidator(node=self)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
# since there are potentially many services, create an request manager that can map service name
|
||||
self._service_request_manager = RequestManager()
|
||||
rm.add_request("service", RequestType(func=self._service_request_manager))
|
||||
rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on))
|
||||
self._nic_request_manager = RequestManager()
|
||||
rm.add_request("network_interface", RequestType(func=self._nic_request_manager))
|
||||
rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on))
|
||||
|
||||
rm.add_request("file_system", RequestType(func=self.file_system._request_manager))
|
||||
rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on))
|
||||
|
||||
# currently we don't have any applications nor processes, so these will be empty
|
||||
self._process_request_manager = RequestManager()
|
||||
rm.add_request("process", RequestType(func=self._process_request_manager))
|
||||
rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on))
|
||||
self._application_request_manager = RequestManager()
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager))
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on))
|
||||
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.reveal_to_red()))
|
||||
rm.add_request(
|
||||
"scan",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
|
||||
rm.add_request("reset", RequestType(func=lambda request, context: self.reset())) # TODO implement node reset
|
||||
rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
|
||||
rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
|
||||
rm.add_request(
|
||||
"shutdown",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
|
||||
),
|
||||
)
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
|
||||
rm.add_request(
|
||||
"reset",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),
|
||||
) # TODO implement node reset
|
||||
rm.add_request(
|
||||
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
|
||||
) # TODO implement logon request
|
||||
rm.add_request(
|
||||
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
|
||||
) # TODO implement logoff request
|
||||
|
||||
self._os_request_manager = RequestManager()
|
||||
self._os_request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager))
|
||||
self._os_request_manager.add_request(
|
||||
"scan",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on),
|
||||
)
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -973,7 +1023,7 @@ class Node(SimComponent):
|
||||
|
||||
self.file_system.apply_timestep(timestep=timestep)
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""
|
||||
Scan the node and all the items within it.
|
||||
|
||||
@@ -987,8 +1037,9 @@ class Node(SimComponent):
|
||||
to the red agent.
|
||||
"""
|
||||
self.node_scan_countdown = self.node_scan_duration
|
||||
return True
|
||||
|
||||
def reveal_to_red(self) -> None:
|
||||
def reveal_to_red(self) -> bool:
|
||||
"""
|
||||
Reveals the node and all the items within it to the red agent.
|
||||
|
||||
@@ -1002,34 +1053,40 @@ class Node(SimComponent):
|
||||
`revealed_to_red` to `True`.
|
||||
"""
|
||||
self.red_scan_countdown = self.node_scan_duration
|
||||
return True
|
||||
|
||||
def power_on(self):
|
||||
def power_on(self) -> bool:
|
||||
"""Power on the Node, enabling its NICs if it is in the OFF state."""
|
||||
if self.operating_state == NodeOperatingState.OFF:
|
||||
self.operating_state = NodeOperatingState.BOOTING
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
|
||||
if self.start_up_duration <= 0:
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self._start_up_actions()
|
||||
self.sys_log.info("Power on")
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.enable()
|
||||
return True
|
||||
if self.operating_state == NodeOperatingState.OFF:
|
||||
self.operating_state = NodeOperatingState.BOOTING
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
return True
|
||||
|
||||
def power_off(self):
|
||||
return False
|
||||
|
||||
def power_off(self) -> bool:
|
||||
"""Power off the Node, disabling its NICs if it is in the ON state."""
|
||||
if self.shut_down_duration <= 0:
|
||||
self._shut_down_actions()
|
||||
self.operating_state = NodeOperatingState.OFF
|
||||
self.sys_log.info("Power off")
|
||||
return True
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.disable()
|
||||
self.operating_state = NodeOperatingState.SHUTTING_DOWN
|
||||
self.shut_down_countdown = self.shut_down_duration
|
||||
return True
|
||||
return False
|
||||
|
||||
if self.shut_down_duration <= 0:
|
||||
self._shut_down_actions()
|
||||
self.operating_state = NodeOperatingState.OFF
|
||||
self.sys_log.info("Power off")
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> bool:
|
||||
"""
|
||||
Resets the node.
|
||||
|
||||
@@ -1040,6 +1097,8 @@ class Node(SimComponent):
|
||||
self.is_resetting = True
|
||||
self.sys_log.info("Resetting")
|
||||
self.power_off()
|
||||
return True
|
||||
return False
|
||||
|
||||
def connect_nic(self, network_interface: NetworkInterface, port_name: Optional[str] = None):
|
||||
"""
|
||||
|
||||
@@ -51,13 +51,15 @@ class WirelessAccessPoint(IPWirelessNetworkInterface):
|
||||
|
||||
return state
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
|
||||
@@ -48,13 +48,15 @@ class WirelessNIC(IPWirelessNetworkInterface):
|
||||
|
||||
return state
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -293,6 +294,11 @@ class AccessControlList(SimComponent):
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
# TODO: Add src and dst wildcard masks as positional args in this request.
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -308,19 +314,24 @@ class AccessControlList(SimComponent):
|
||||
rm.add_request(
|
||||
"add_rule",
|
||||
RequestType(
|
||||
func=lambda request, context: self.add_rule(
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.add_rule(
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
|
||||
rm.add_request(
|
||||
"remove_rule",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.remove_rule(int(request[0])))),
|
||||
)
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
@@ -366,7 +377,7 @@ class AccessControlList(SimComponent):
|
||||
src_port: Optional[Port] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
position: int = 0,
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""
|
||||
Adds a new ACL rule to control network traffic based on specified criteria.
|
||||
|
||||
@@ -423,10 +434,12 @@ class AccessControlList(SimComponent):
|
||||
src_port=src_port,
|
||||
dst_port=dst_port,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.")
|
||||
return False
|
||||
|
||||
def remove_rule(self, position: int) -> None:
|
||||
def remove_rule(self, position: int) -> bool:
|
||||
"""
|
||||
Remove an ACL rule from a specific position.
|
||||
|
||||
@@ -437,8 +450,10 @@ class AccessControlList(SimComponent):
|
||||
rule = self._acl[position] # noqa
|
||||
self._acl[position] = None
|
||||
del rule
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.")
|
||||
return False
|
||||
|
||||
def is_permitted(self, frame: Frame) -> Tuple[bool, ACLRule]:
|
||||
"""Check if a packet with the given properties is permitted through the ACL."""
|
||||
@@ -1082,6 +1097,11 @@ class Router(NetworkNode):
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("acl", RequestType(func=self.acl._request_manager))
|
||||
return rm
|
||||
|
||||
@@ -146,9 +146,12 @@ def arcd_uc2_network() -> Network:
|
||||
)
|
||||
client_1.power_on()
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
db_client_1 = client_1.software_manager.install(DatabaseClient)
|
||||
db_client_1 = client_1.software_manager.software.get("DatabaseClient")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
db_client_1.run()
|
||||
web_browser_1 = client_1.software_manager.software.get("WebBrowser")
|
||||
web_browser_1.target_url = "http://arcd.com/users/"
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
db_manipulation_bot.configure(
|
||||
@@ -170,9 +173,10 @@ def arcd_uc2_network() -> Network:
|
||||
client_2.power_on()
|
||||
client_2.software_manager.install(DatabaseClient)
|
||||
db_client_2 = client_2.software_manager.software.get("DatabaseClient")
|
||||
db_client_2.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
db_client_2.run()
|
||||
web_browser = client_2.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
web_browser_2 = client_2.software_manager.software.get("WebBrowser")
|
||||
web_browser_2.target_url = "http://arcd.com/users/"
|
||||
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
|
||||
|
||||
# Domain Controller
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Dict
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.domain.controller import DomainController
|
||||
from primaite.simulator.network.container import Network
|
||||
@@ -26,12 +27,18 @@ class Simulation(SimComponent):
|
||||
self.network.setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
# pass through network requests to the network objects
|
||||
rm.add_request("network", RequestType(func=self.network._request_manager))
|
||||
# pass through domain requests to the domain object
|
||||
rm.add_request("domain", RequestType(func=self.domain._request_manager))
|
||||
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
|
||||
# if 'do_nothing' is requested, just return a success
|
||||
rm.add_request("do_nothing", RequestType(func=lambda request, context: RequestResponse(status="success")))
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -36,8 +37,13 @@ class DatabaseClient(Application):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: self.execute()))
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
|
||||
return rm
|
||||
|
||||
def execute(self) -> bool:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -74,9 +75,17 @@ class DataManipulationBot(Application):
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -179,23 +188,22 @@ class DataManipulationBot(Application):
|
||||
"""
|
||||
super().run()
|
||||
|
||||
def attack(self):
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
self._application_loop()
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop of the bot, handling the attack process.
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
self.num_executions += 1
|
||||
self.num_executions += 1
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self._logon()
|
||||
@@ -207,8 +215,12 @@ class DataManipulationBot(Application):
|
||||
DataManipulationAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = DataManipulationAttackStage.NOT_STARTED
|
||||
|
||||
return True
|
||||
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -57,9 +58,17 @@ class DoSBot(DatabaseClient):
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -97,26 +106,26 @@ class DoSBot(DatabaseClient):
|
||||
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
|
||||
)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> bool:
|
||||
"""Run the Denial of Service Bot."""
|
||||
super().run()
|
||||
self._application_loop()
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
return False
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return
|
||||
return True
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
@@ -126,6 +135,7 @@ class DoSBot(DatabaseClient):
|
||||
self.attack_stage = DoSAttackStage.NOT_STARTED
|
||||
else:
|
||||
self.attack_stage = DoSAttackStage.COMPLETED
|
||||
return True
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from urllib.parse import urlparse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
@@ -50,9 +51,17 @@ class WebBrowser(Application):
|
||||
self.run()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa
|
||||
name="execute",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.get_webpage())
|
||||
), # noqa
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
@@ -79,15 +80,20 @@ class Service(IOSoftware):
|
||||
return super().receive(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("stop", RequestType(func=lambda request, context: self.stop()))
|
||||
rm.add_request("start", RequestType(func=lambda request, context: self.start()))
|
||||
rm.add_request("pause", RequestType(func=lambda request, context: self.pause()))
|
||||
rm.add_request("resume", RequestType(func=lambda request, context: self.resume()))
|
||||
rm.add_request("restart", RequestType(func=lambda request, context: self.restart()))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop())))
|
||||
rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start())))
|
||||
rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause())))
|
||||
rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume())))
|
||||
rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart())))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
@@ -106,17 +112,19 @@ class Service(IOSoftware):
|
||||
state["health_state_visible"] = self.health_state_visible.value
|
||||
return state
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self) -> bool:
|
||||
"""Stop the service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
def start(self, **kwargs) -> bool:
|
||||
"""Start the service."""
|
||||
# cant start the service if the node it is on is off
|
||||
if not super()._can_perform_action():
|
||||
return
|
||||
return False
|
||||
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
@@ -124,36 +132,47 @@ class Service(IOSoftware):
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
return True
|
||||
return False
|
||||
|
||||
def pause(self) -> None:
|
||||
def pause(self) -> bool:
|
||||
"""Pause the service."""
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
return True
|
||||
return False
|
||||
|
||||
def resume(self) -> None:
|
||||
def resume(self) -> bool:
|
||||
"""Resume paused service."""
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
return True
|
||||
return False
|
||||
|
||||
def restart(self) -> None:
|
||||
def restart(self) -> bool:
|
||||
"""Restart running service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
self.restart_countdown = self.restart_duration
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self) -> None:
|
||||
def disable(self) -> bool:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
return True
|
||||
|
||||
def enable(self) -> None:
|
||||
def enable(self) -> bool:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -101,20 +102,27 @@ class Software(SimComponent):
|
||||
"Current number of ticks left to patch the software."
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
"compromise",
|
||||
RequestType(
|
||||
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
),
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"patch",
|
||||
RequestType(
|
||||
func=lambda request, context: self.patch(),
|
||||
func=lambda request, context: RequestResponse.from_bool(self.patch()),
|
||||
),
|
||||
)
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
return rm
|
||||
|
||||
def _get_session_details(self, session_id: str) -> Session:
|
||||
@@ -148,7 +156,7 @@ class Software(SimComponent):
|
||||
)
|
||||
return state
|
||||
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> None:
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> bool:
|
||||
"""
|
||||
Assign a new health state to this software.
|
||||
|
||||
@@ -160,6 +168,7 @@ class Software(SimComponent):
|
||||
:type health_state: SoftwareHealthState
|
||||
"""
|
||||
self.health_state_actual = health_state
|
||||
return True
|
||||
|
||||
def install(self) -> None:
|
||||
"""
|
||||
@@ -180,15 +189,18 @@ class Software(SimComponent):
|
||||
"""
|
||||
pass
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Update the observed health status to match the actual health status."""
|
||||
self.health_state_visible = self.health_state_actual
|
||||
return True
|
||||
|
||||
def patch(self) -> None:
|
||||
def patch(self) -> bool:
|
||||
"""Perform a patch on the software."""
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
|
||||
@@ -12,6 +12,8 @@ def test_passing_actions_down(monkeypatch) -> None:
|
||||
sim = Simulation()
|
||||
|
||||
pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0")
|
||||
pc1.start_up_duration = 0
|
||||
pc1.power_on()
|
||||
pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0")
|
||||
srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0")
|
||||
s1 = Switch(hostname="switch1")
|
||||
|
||||
@@ -125,8 +125,8 @@ def test_describe_state_nmne(uc2_network):
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
db_server_nic_state = db_server_nic.describe_state()
|
||||
uc2_network.apply_timestep(timestep=0)
|
||||
assert web_server_nic_state["nmne"] == {}
|
||||
assert db_server_nic_state["nmne"] == {}
|
||||
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}}
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
|
||||
|
||||
# Perform another "DELETE" query
|
||||
db_client.query("DELETE")
|
||||
@@ -135,8 +135,8 @@ def test_describe_state_nmne(uc2_network):
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
db_server_nic_state = db_server_nic.describe_state()
|
||||
uc2_network.apply_timestep(timestep=0)
|
||||
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}}
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
|
||||
assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}}
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}}
|
||||
|
||||
|
||||
def test_capture_nmne_observations(uc2_network):
|
||||
|
||||
0
tests/integration_tests/test_simulation/__init__.py
Normal file
0
tests/integration_tests/test_simulation/__init__.py
Normal file
160
tests/integration_tests/test_simulation/test_request_response.py
Normal file
160
tests/integration_tests/test_simulation/test_request_response.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# some test cases:
|
||||
# 0. test that sending a request to a valid target results in a success
|
||||
# 1. test that sending a request to a component that doesn't exist results in a failure
|
||||
# 2. test that sending a request to a software on a turned-off component results in a failure
|
||||
# 3. test every implemented action under several different situation, some of which should lead to a success and some to a failure.
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from tests.conftest import TestApplication, TestService
|
||||
|
||||
|
||||
def test_successful_application_requests(example_network):
|
||||
net = example_network
|
||||
|
||||
client_1 = net.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(TestApplication)
|
||||
client_1.software_manager.software.get("TestApplication").run()
|
||||
|
||||
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "patch"])
|
||||
assert resp_2 == RequestResponse(status="success", data={})
|
||||
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
|
||||
assert resp_3 == RequestResponse(status="success", data={})
|
||||
|
||||
|
||||
def test_successful_service_requests(example_network):
|
||||
net = example_network
|
||||
server_1 = net.get_node_by_hostname("server_1")
|
||||
server_1.software_manager.install(TestService)
|
||||
|
||||
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
|
||||
for verb in [
|
||||
"disable",
|
||||
"enable",
|
||||
"start",
|
||||
"stop",
|
||||
"start",
|
||||
"restart",
|
||||
"pause",
|
||||
"resume",
|
||||
"compromise",
|
||||
"scan",
|
||||
"patch",
|
||||
]:
|
||||
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
server_1.apply_timestep(timestep=1)
|
||||
# lazily apply timestep 7 times to make absolutely sure any time-based things like restart have a chance to finish
|
||||
|
||||
|
||||
def test_non_existent_requests(example_network):
|
||||
net = example_network
|
||||
resp_1 = net.apply_request(["fake"])
|
||||
assert resp_1.status == "unreachable"
|
||||
resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"])
|
||||
assert resp_2.status == "unreachable"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node_request",
|
||||
[
|
||||
["node", "client_1", "file_system", "folder", "root", "scan"],
|
||||
["node", "client_1", "os", "scan"],
|
||||
["node", "client_1", "service", "DNSClient", "stop"],
|
||||
["node", "client_1", "application", "WebBrowser", "scan"],
|
||||
["node", "client_1", "network_interface", 1, "disable"],
|
||||
],
|
||||
)
|
||||
def test_request_fails_if_node_off(example_network, node_request):
|
||||
"""Test that requests succeed when the node is on, and fail if the node is off."""
|
||||
net = example_network
|
||||
client_1: HostNode = net.get_node_by_hostname("client_1")
|
||||
client_1.shut_down_duration = 0
|
||||
|
||||
assert client_1.operating_state == NodeOperatingState.ON
|
||||
resp_1 = net.apply_request(node_request)
|
||||
assert resp_1.status == "success"
|
||||
|
||||
client_1.power_off()
|
||||
assert client_1.operating_state == NodeOperatingState.OFF
|
||||
resp_2 = net.apply_request(node_request)
|
||||
assert resp_2.status == "failure"
|
||||
|
||||
|
||||
class TestDataManipulationGreenRequests:
|
||||
def test_node_off(self, uc2_network):
|
||||
"""Test that green requests succeed when the node is on and fail if the node is off."""
|
||||
net = uc2_network
|
||||
|
||||
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
|
||||
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
|
||||
client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
|
||||
client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"])
|
||||
assert client_1_browser_execute.status == "success"
|
||||
assert client_1_db_client_execute.status == "success"
|
||||
assert client_2_browser_execute.status == "success"
|
||||
assert client_2_db_client_execute.status == "success"
|
||||
|
||||
client_1 = net.get_node_by_hostname("client_1")
|
||||
client_2 = net.get_node_by_hostname("client_2")
|
||||
|
||||
client_1.shut_down_duration = 0
|
||||
client_1.power_off()
|
||||
client_2.shut_down_duration = 0
|
||||
client_2.power_off()
|
||||
|
||||
client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
|
||||
client_1_db_client_execute_off = net.apply_request(
|
||||
["node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
)
|
||||
client_2_browser_execute_off = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
|
||||
client_2_db_client_execute_off = net.apply_request(
|
||||
["node", "client_2", "application", "DatabaseClient", "execute"]
|
||||
)
|
||||
assert client_1_browser_execute_off.status == "failure"
|
||||
assert client_1_db_client_execute_off.status == "failure"
|
||||
assert client_2_browser_execute_off.status == "failure"
|
||||
assert client_2_db_client_execute_off.status == "failure"
|
||||
|
||||
def test_acl_block(self, uc2_network):
|
||||
"""Test that green requests succeed when not blocked by ACLs but fail when blocked."""
|
||||
net = uc2_network
|
||||
|
||||
router: Router = net.get_node_by_hostname("router_1")
|
||||
client_1: HostNode = net.get_node_by_hostname("client_1")
|
||||
client_2: HostNode = net.get_node_by_hostname("client_2")
|
||||
|
||||
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
|
||||
client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
|
||||
assert client_1_browser_execute.status == "success"
|
||||
assert client_2_browser_execute.status == "success"
|
||||
|
||||
router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
|
||||
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
|
||||
client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"])
|
||||
assert client_1_browser_execute.status == "failure"
|
||||
assert client_2_browser_execute.status == "failure"
|
||||
|
||||
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
|
||||
client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"])
|
||||
assert client_1_db_client_execute.status == "success"
|
||||
assert client_2_db_client_execute.status == "success"
|
||||
|
||||
router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
|
||||
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
|
||||
client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"])
|
||||
assert client_1_db_client_execute.status == "failure"
|
||||
assert client_2_db_client_execute.status == "failure"
|
||||
0
tests/unit_tests/_primaite/_interface/__init__.py
Normal file
0
tests/unit_tests/_primaite/_interface/__init__.py
Normal file
32
tests/unit_tests/_primaite/_interface/test_request.py
Normal file
32
tests/unit_tests/_primaite/_interface/test_request.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
|
||||
|
||||
def test_creating_response_object():
|
||||
"""Test that we can create a response object with given parameters."""
|
||||
r1 = RequestResponse(status="success", data={"test_data": 1, "other_data": 2})
|
||||
r2 = RequestResponse(status="unreachable")
|
||||
r3 = RequestResponse(data={"test_data": "is_good"})
|
||||
r4 = RequestResponse()
|
||||
assert isinstance(r1, RequestResponse)
|
||||
assert isinstance(r2, RequestResponse)
|
||||
assert isinstance(r3, RequestResponse)
|
||||
assert isinstance(r4, RequestResponse)
|
||||
|
||||
|
||||
def test_creating_response_from_boolean():
|
||||
"""Test that we can build a response with a single boolean."""
|
||||
r1 = RequestResponse.from_bool(True)
|
||||
assert r1.status == "success"
|
||||
|
||||
r2 = RequestResponse.from_bool(False)
|
||||
assert r2.status == "failure"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
r3 = RequestResponse.from_bool(1)
|
||||
with pytest.raises(ValidationError):
|
||||
r4 = RequestResponse.from_bool("good")
|
||||
with pytest.raises(ValidationError):
|
||||
r5 = RequestResponse.from_bool({"data": True})
|
||||
Reference in New Issue
Block a user