Merged PR 184: Connect up actions in the simulator

## Summary
Finishes (?) the action system.
- Defines all UC2 common action space actions on SimComponents
- Links up SimComponents and their children, so actions can be passed from parent to child
- Add a function for enumerating all possible actions that exist on a SimComponent. (will be used for generating action space)
- add documentation for action management

note: I know that the way I approached this is a bit convoluted but It's just what I came up with to allow the actions to be as flexible and modular as the SimComponents themselves.

## Test process
Tested that the functionality works in scratch notebook. But also in the process of adding unit/integration tests now.

## Checklist
- [x] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [x] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #1923
This commit is contained in:
Marek Wolan
2023-10-12 09:00:30 +00:00
20 changed files with 651 additions and 480 deletions

3
.gitignore vendored
View File

@@ -152,4 +152,5 @@ simulation_output/
# benchmark session outputs
benchmark/output
src/primaite/notebooks/scratch.ipynb
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py

View File

@@ -0,0 +1,88 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Actions System
==============
``SimComponent``s in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``Action``.
Just like other aspects of SimComponent, the actions are not managed centrally for the whole simulation, but instead they are dynamically created and updated based on the nodes, links, and other components that currently exist. This was achieved with the following design decisions:
- API
An 'action' contains two elements:
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
2. ``context`` - optional extra information that can be used to decide how to process the action. This is formatted as a dictionary. For example, if the action requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient.
- request
The request is a list of strings which help specify who should handle the request. The strings in the request list help RequestManagers traverse the 'ownership tree' of SimComponent. The example given above would be handled in the following way:
1. ``Simulation`` receives `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
The first element of the action is ``network``, therefore it passes the action down to its network.
2. ``Network`` receives `['node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
The first element of the action is ``node``, therefore the network looks at the node uuid and passes the action down to the node with that uuid.
3. ``Node`` receives `['service', '<service-uuid>', 'restart']`.
The first element of the action is ``service``, therefore the node looks at the service uuid and passes the rest of the action to the service with that uuid.
4. ``Service`` receives ``['restart']``.
Since ``restart`` is a defined action in the service's own RequestManager, the service performs a restart.
Techincal Detail
================
This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.Action`, and :py:class:`primaite.simulator.core.RequestManager`.
Action
------
The ``Action`` object stores a reference to a method that performs the action, for example a node could have an action that stores a reference to ``self.turn_on()``. Techincally, this can be any callable that accepts `request, context` as it's parameters. In practice, this is often defined using ``lambda`` functions within a component's ``self._init_request_manager()`` method. Optionally, the ``Action`` object can also hold a validator that will permit/deny the action depending on context.
RequestManager
-------------
The ``RequestManager`` object stores a mapping between strings and actions. It is responsible for processing the ``request`` and passing it down the ownership tree. Techincally, the ``RequestManager`` is itself a callable that accepts `request, context` tuple, and so it can be chained with other action managers.
A simple example without chaining can be seen in the :py:class:`primaite.simulator.file_system.file_system.File` class.
.. code-block:: python
class File(FileSystemItemABC):
...
def _init_request_manager(self):
...
request_manager.add_action("scan", Action(func=lambda request, context: self.scan()))
request_manager.add_action("repair", Action(func=lambda request, context: self.repair()))
request_manager.add_action("restore", Action(func=lambda request, context: self.restore()))
*ellipses (``...``) used to omit code impertinent to this explanation*
Chaining RequestManagers
-----------------------
Since the method for performing an action needs to accept `request, context` as parameters, and RequestManager itself is a callable that accepts `request, context` as parameters, it possible to use RequestManager as an action. In fact, that is how PrimAITE deals with traversing the ownership tree. Each time an RequestManager accepts a request, it pops the first elements and uses it to decide to which Action it should send the remaining request. However, the Action could have another RequestManager as it's function, therefore the request will be routed again. Each time the request is passed to a new action manager, the first element is popped.
An example of how this works is in the :py:class:`primaite.simulator.network.hardware.base.Node` class.
.. code-block:: python
class Node(SimComponent):
...
def _init_request_manager(self):
...
# a regular action which is processed by the Node itself
request_manager.add_action("turn_on", Action(func=lambda request, context: self.turn_on()))
# if the Node receives a request where the first word is 'service', it will use a dummy manager
# called self._service_request_manager to pass on the reqeust to the relevant service. This dummy
# manager is simply here to map the service UUID that that service's own action manager. This is
# done because the next string after "service" is always the uuid of that service, so we need an
# RequestManager to pop that string before sending it onto the relevant service's RequestManager.
self._service_request_manager = RequestManager()
request_manager.add_action("service", Action(func=self._service_request_manager))
...
def install_service(self, service):
self.services[service.uuid] = service
...
# Here, the service UUID is registered to allow passing actions between the node and the service.
self._service_request_manager.add_action(service.uuid, Action(func=service._request_manager))

View File

@@ -23,3 +23,4 @@ Contents
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/software
action_system

View File

@@ -42,15 +42,15 @@ snippet demonstrates usage of the ``ActionPermissionValidator``.
.. code:: python
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import Action, RequestManager, SimComponent
from primaite.simulator.domain.controller import AccountGroup, GroupMembershipValidator
class Smartphone(SimComponent):
name: str
apps = []
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_action(
"reset_factory_settings",
Action(

View File

@@ -0,0 +1,107 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.networks import arcd_uc2_network\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = arcd_uc2_network()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### set up some services to test if actions are working"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv = net.get_node_by_hostname('database_server')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.system.services.database_service import DatabaseService"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_svc = DatabaseService(file_system=db_serv.file_system)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv.install_service(db_svc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv.describe_state()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
# flake8: noqa
"""Core of the PrimAITE Simulator."""
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional
from typing import Callable, ClassVar, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel, ConfigDict
@@ -10,9 +11,9 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
class ActionPermissionValidator(ABC):
class RequestPermissionValidator(BaseModel):
"""
Base class for action validators.
Base class for request validators.
The permissions manager is designed to be generic. So, although in the first instance the permissions
are evaluated purely on membership to AccountGroup, this class can support validating permissions based on any
@@ -21,109 +22,126 @@ class ActionPermissionValidator(ABC):
@abstractmethod
def __call__(self, request: List[str], context: Dict) -> bool:
"""Use the request and context paramters to decide whether the action should be permitted."""
"""Use the request and context paramters to decide whether the request should be permitted."""
pass
class AllowAllValidator(ActionPermissionValidator):
"""Always allows the action."""
class AllowAllValidator(RequestPermissionValidator):
"""Always allows the request."""
def __call__(self, request: List[str], context: Dict) -> bool:
"""Always allow the action."""
"""Always allow the request."""
return True
class Action:
class RequestType(BaseModel):
"""
This object stores data related to a single action.
This object stores data related to a single request type.
This includes the callable that can execute the action request, and the validator that will decide whether
the action can be performed or not.
This includes the callable that can execute the request, and the validator that will decide whether
the request can be performed or not.
"""
def __init__(
self, func: Callable[[List[str], Dict], None], validator: ActionPermissionValidator = AllowAllValidator()
) -> None:
func: Callable[[List[str], Dict], None]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this request will be given something like ``func = lambda request, context: self.turn_off()``.
``func`` can also be another request manager, since RequestManager is a callable with a signature that matches what is
expected by ``func``.
"""
validator: RequestPermissionValidator = AllowAllValidator()
"""
``validator`` is an instance of ``RequestPermissionValidator``. This is essentially a callable that
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
the request. The default validator will allow
"""
class RequestManager(BaseModel):
"""
RequestManager is used by `SimComponent` instances to keep track of requests.
Its main purpose is to be a lookup from request name to request function and corresponding validation function. This
class is responsible for providing a consistent API for processing requests as well as helpful error messages.
"""
request_types: Dict[str, RequestType] = {}
"""maps request name to an RequestType object."""
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
"""
Save the functions that are for this action.
Process an request request.
Here's a description for the intended use of both of these.
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
``validator`` is an instance of a subclass of `ActionPermissionValidator`. This is essentially a callable that
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
the action.
:param func: Function that performs the request.
:type func: Callable[[List[str], Dict], None]
:param validator: Function that checks if the request is authenticated given the context. By default, if no
validator is provided, an 'allow all' validator is added which permits all requests.
:type validator: ActionPermissionValidator
"""
self.func: Callable[[List[str], Dict], None] = func
self.validator: ActionPermissionValidator = validator
class ActionManager:
"""
ActionManager is used by `SimComponent` instances to keep track of actions.
Its main purpose is to be a lookup from action name to action function and corresponding validation function. This
class is responsible for providing a consistent API for processing actions as well as helpful error messages.
"""
def __init__(self) -> None:
"""Initialise ActionManager with an empty action lookup."""
self.actions: Dict[str, Action] = {}
def process_request(self, request: List[str], context: Dict) -> None:
"""Process an action request.
:param request: A list of strings which specify what action to take. The first string must be one of the allowed
actions, i.e. it must be a key of self.actions. The subsequent strings in the list are passed as parameters
to the action function.
:param request: A list of strings describing the request. The first string must be one of the allowed
request names, i.e. it must be a key of self.request_types. The subsequent strings in the list are passed as
parameters to the request function.
:type request: List[str]
:param context: Dictionary of additional information necessary to process or validate the request.
:type context: Dict
:raises RuntimeError: If the request parameter does not have a valid action identifier as the first item.
:raises RuntimeError: If the request parameter does not have a valid request name as the first item.
"""
action_key = request[0]
request_key = request[0]
if action_key not in self.actions:
if request_key not in self.request_types:
msg = (
f"Action request {request} could not be processed because {action_key} is not a valid action",
"within this ActionManager",
f"Request {request} could not be processed because {request_key} is not a valid request name",
"within this RequestManager",
)
_LOGGER.error(msg)
raise RuntimeError(msg)
action = self.actions[action_key]
action_options = request[1:]
request_type = self.request_types[request_key]
request_options = request[1:]
if not action.validator(action_options, context):
_LOGGER.debug(f"Action request {request} was denied due to insufficient permissions")
if not request_type.validator(request_options, context):
_LOGGER.debug(f"Request {request} was denied due to insufficient permissions")
return
action.func(action_options, context)
request_type.func(request_options, context)
def add_action(self, name: str, action: Action) -> None:
"""Add an action to this action manager.
:param name: The string associated to this action.
:type name: str
:param action: Action object.
:type action: Action
def add_request(self, name: str, request_type: RequestType) -> None:
"""
if name in self.actions:
msg = f"Attempted to register an action but the action name {name} is already taken."
Add a request type to this request manager.
:param name: The string associated to this request.
:type name: str
:param request_type: Request type object which contains information about how to resolve request.
:type request_type: RequestType
"""
if name in self.request_types:
msg = f"Overwriting request type {name}."
_LOGGER.warn(msg)
self.request_types[name] = request_type
def remove_request(self, name: str) -> None:
"""
Remove a request from this manager.
:param name: name identifier of the request
:type name: str
"""
if name not in self.request_types:
msg = f"Attempted to remove request {name} from request manager, but it was not registered."
_LOGGER.error(msg)
raise RuntimeError(msg)
self.actions[name] = action
self.request_types.pop(name)
def get_request_types_recursively(self) -> List[List[str]]:
"""Recursively generate request tree for this component."""
requests = []
for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively()
sub_requests = [[req_name] + a for a in sub_requests]
requests.extend(sub_requests)
else:
requests.append([req_name])
return requests
class SimComponent(BaseModel):
@@ -139,30 +157,30 @@ class SimComponent(BaseModel):
if not kwargs.get("uuid"):
kwargs["uuid"] = str(uuid4())
super().__init__(**kwargs)
self._action_manager: ActionManager = self._init_action_manager()
self.parent: Optional["SimComponent"] = None
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
def _init_action_manager(self) -> ActionManager:
def _init_request_manager(self) -> RequestManager:
"""
Initialise the action manager for this component.
Initialise the request manager for this component.
When using a hierarchy of components, the child classes should call the parent class's _init_action_manager and
add additional actions on top of the existing generic ones.
When using a hierarchy of components, the child classes should call the parent class's _init_request_manager and
add additional requests on top of the existing generic ones.
Example usage for inherited classes:
..code::python
class WebBrowser(Application):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager() # all actions generic to any Application get initialised
am.add_action(...) # initialise any actions specific to the web browser
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager() # all requests generic to any Application get initialised
am.add_request(...) # initialise any requests specific to the web browser
return am
:return: Actiona manager object belonging to this sim component.
:rtype: ActionManager
:return: Request manager object belonging to this sim component.
:rtype: RequestManager
"""
return ActionManager()
return RequestManager()
@abstractmethod
def describe_state(self) -> Dict:
@@ -178,27 +196,27 @@ class SimComponent(BaseModel):
}
return state
def apply_action(self, action: List[str], context: Dict = {}) -> None:
def apply_request(self, request: List[str], context: Dict = {}) -> None:
"""
Apply an action to a simulation component. Action data is passed in as a 'namespaced' list of strings.
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
If the list only has one element, the action is intended to be applied directly to this object. If the list has
multiple entries, the action is passed to the child of this object specified by the first one or two entries.
If the list only has one element, the request is intended to be applied directly to this object. If the list has
multiple entries, the request is passed to the child of this object specified by the first one or two entries.
This is essentially a namespace.
For example, ["turn_on",] is meant to apply an action of 'turn on' to this component.
For example, ["turn_on",] is meant to apply a request of 'turn on' to this component.
However, ["services", "email_client", "turn_on"] is meant to 'turn on' this component's email client service.
:param action: List describing the action to apply to this object.
:type action: List[str]
:param request: List describing the request to apply to this object.
:type request: List[str]
:param: context: Dict containing context for actions
:param: context: Dict containing context for requests
:type context: Dict
"""
if self.action_manager is None:
if self._request_manager is None:
return
self.action_manager.process_request(action, context)
self._request_manager(request, context)
def apply_timestep(self, timestep: int) -> None:
"""
@@ -216,3 +234,20 @@ class SimComponent(BaseModel):
Override this method with anything that needs to happen within the component for it to be reset.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.
:return: Parent object.
:rtype: SimComponent
"""
return self._parent
@parent.setter
def parent(self, new_parent: Union["SimComponent", None]) -> None:
if self._parent and new_parent:
msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}"
_LOGGER.warn(msg)
raise RuntimeWarning(msg)
self._parent = new_parent

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Dict, Final, List, Literal, Tuple
from primaite.simulator.core import Action, ActionManager, ActionPermissionValidator, SimComponent
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
from primaite.simulator.domain.account import Account, AccountType
@@ -43,16 +43,10 @@ class AccountGroup(Enum):
"For full access"
class GroupMembershipValidator(ActionPermissionValidator):
class GroupMembershipValidator(RequestPermissionValidator):
"""Permit actions based on group membership."""
def __init__(self, allowed_groups: List[AccountGroup]) -> None:
"""Store a list of groups that should be granted permission.
:param allowed_groups: List of AccountGroups that are permitted to perform some action.
:type allowed_groups: List[AccountGroup]
"""
self.allowed_groups = allowed_groups
allowed_groups: List[AccountGroup]
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
@@ -85,15 +79,15 @@ class DomainController(SimComponent):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
# Action 'account' matches requests like:
# ['account', '<account-uuid>', *account_action]
am.add_action(
am.add_request(
"account",
Action(
func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context),
validator=GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]),
RequestType(
func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context),
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
),
)
return am

View File

@@ -9,7 +9,7 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
from primaite.simulator.system.core.sys_log import SysLog
@@ -94,6 +94,17 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
self._folder_request_manager = RequestManager()
am.add_request("folder", RequestType(func=self._folder_request_manager))
self._file_request_manager = RequestManager()
am.add_request("file", RequestType(func=self._file_request_manager))
return am
@property
def size(self) -> int:
"""
@@ -154,6 +165,7 @@ class FileSystem(SimComponent):
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self.sys_log.info(f"Created folder /{folder.name}")
self._folder_request_manager.add_request(folder.uuid, RequestType(func=folder._request_manager))
return folder
def delete_folder(self, folder_name: str):
@@ -172,6 +184,7 @@ class FileSystem(SimComponent):
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
self._folder_request_manager.remove_request(folder.uuid)
else:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
@@ -213,6 +226,7 @@ class FileSystem(SimComponent):
)
folder.add_file(file)
self.sys_log.info(f"Created file /{file.path}")
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
return file
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
@@ -240,6 +254,7 @@ class FileSystem(SimComponent):
file = folder.get_file(file_name)
if file:
folder.remove_file(file)
self._file_request_manager.remove_request(file.uuid)
self.sys_log.info(f"Deleted file /{file.path}")
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
@@ -317,6 +332,18 @@ class Folder(FileSystemItemABC):
is_quarantined: bool = False
"Flag that marks the folder as quarantined if true."
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request
return am
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -482,6 +509,18 @@ class File(FileSystemItemABC):
with open(self.sim_path, mode="a"):
pass
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request
return am
def make_copy(self, dst_folder: Folder) -> File:
"""
Create a copy of the current File object in the given destination folder.

View File

@@ -6,7 +6,7 @@ from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
@@ -37,21 +37,18 @@ class Network(SimComponent):
Initialise the network.
Constructs the network and sets up its initial state including
the action manager and an empty MultiGraph for topology representation.
the request manager and an empty MultiGraph for topology representation.
"""
super().__init__(**kwargs)
self._nx_graph = MultiGraph()
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action(
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
self._node_request_manager = RequestManager()
am.add_request(
"node",
Action(
func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
validator=AllowAllValidator(),
),
RequestType(func=self._node_request_manager),
)
return am
@@ -184,7 +181,8 @@ class Network(SimComponent):
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.uuid, request_type=RequestType(func=node._request_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -218,6 +216,7 @@ class Network(SimComponent):
break
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_request_manager.remove_request(name=node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""

View File

@@ -12,7 +12,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
@@ -89,9 +89,9 @@ class NIC(SimComponent):
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
wake_on_lan: bool = False
"Indicates if the NIC supports Wake-on-LAN functionality."
connected_node: Optional[Node] = None
_connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
connected_link: Optional[Link] = None
_connected_link: Optional[Link] = None
"The Link to which the NIC is connected."
enabled: bool = False
"Indicates whether the NIC is enabled."
@@ -135,17 +135,23 @@ class NIC(SimComponent):
{
"ip_adress": str(self.ip_address),
"subnet_mask": str(self.subnet_mask),
"gateway": str(self.gateway),
"mac_address": self.mac_address,
"speed": self.speed,
"mtu": self.mtu,
"wake_on_lan": self.wake_on_lan,
"dns_servers": self.dns_servers,
"enabled": self.enabled,
}
)
return state
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("enable", RequestType(func=lambda request, context: self.enable()))
am.add_request("disable", RequestType(func=lambda request, context: self.disable()))
return am
@property
def ip_network(self) -> IPv4Network:
"""
@@ -159,21 +165,21 @@ class NIC(SimComponent):
"""Attempt to enable the NIC."""
if self.enabled:
return
if not self.connected_node:
if not self._connected_node:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
return
if not self.connected_link:
if not self._connected_link:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
return
self.enabled = True
self.connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, ip_address=self.ip_address)
if self.connected_link:
self.connected_link.endpoint_up()
self._connected_node.sys_log.info(f"NIC {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, ip_address=self.ip_address)
if self._connected_link:
self._connected_link.endpoint_up()
def disable(self):
"""Disable the NIC."""
@@ -181,12 +187,12 @@ class NIC(SimComponent):
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
if self._connected_node:
self._connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.debug(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if self._connected_link:
self._connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -195,26 +201,26 @@ class NIC(SimComponent):
:param link: The link to which the NIC is connected.
:type link: :class:`~primaite.simulator.network.transmission.physical_layer.Link`
"""
if self.connected_link:
if self._connected_link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it already has a connection")
return
if self.connected_link == link:
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to NIC ({self.mac_address}) as it is already connected")
return
# TODO: Inform the Node that a link has been connected
self.connected_link = link
self._connected_link = link
self.enable()
_LOGGER.debug(f"NIC {self} connected to Link {link}")
def disconnect_link(self):
"""Disconnect the NIC from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
self.connected_link.endpoint_b = None
self.connected_link = None
if self._connected_link.endpoint_a == self:
self._connected_link.endpoint_a = None
if self._connected_link.endpoint_b == self:
self._connected_link.endpoint_b = None
self._connected_link = None
def add_dns_server(self, ip_address: IPv4Address):
"""
@@ -244,7 +250,7 @@ class NIC(SimComponent):
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the NIC is not enabled
return False
@@ -263,7 +269,7 @@ class NIC(SimComponent):
self.pcap.capture(frame)
# If this destination or is broadcast
if frame.ethernet.dst_mac_addr == self.mac_address or frame.ethernet.dst_mac_addr == "ff:ff:ff:ff:ff:ff":
self.connected_node.receive_frame(frame=frame, from_nic=self)
self._connected_node.receive_frame(frame=frame, from_nic=self)
return True
return False
@@ -288,9 +294,9 @@ class SwitchPort(SimComponent):
"The speed of the SwitchPort in Mbps. Default is 100 Mbps."
mtu: int = 1500
"The Maximum Transmission Unit (MTU) of the SwitchPort in Bytes. Default is 1500 B"
connected_node: Optional[Node] = None
_connected_node: Optional[Node] = None
"The Node to which the SwitchPort is connected."
connected_link: Optional[Link] = None
_connected_link: Optional[Link] = None
"The Link to which the SwitchPort is connected."
enabled: bool = False
"Indicates whether the SwitchPort is enabled."
@@ -327,31 +333,31 @@ class SwitchPort(SimComponent):
if self.enabled:
return
if not self.connected_node:
if not self._connected_node:
_LOGGER.error(f"SwitchPort {self} cannot be enabled as it is not connected to a Node")
return
if self.connected_node.operating_state != NodeOperatingState.ON:
self.connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.info(f"SwitchPort {self} cannot be enabled as the endpoint is not turned on")
return
self.enabled = True
self.connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self.connected_node.hostname, switch_port_number=self.port_num)
if self.connected_link:
self.connected_link.endpoint_up()
self._connected_node.sys_log.info(f"SwitchPort {self} enabled")
self.pcap = PacketCapture(hostname=self._connected_node.hostname, switch_port_number=self.port_num)
if self._connected_link:
self._connected_link.endpoint_up()
def disable(self):
"""Disable the SwitchPort."""
if not self.enabled:
return
self.enabled = False
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
if self._connected_node:
self._connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.debug(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
if self._connected_link:
self._connected_link.endpoint_down()
def connect_link(self, link: Link):
"""
@@ -359,26 +365,26 @@ class SwitchPort(SimComponent):
:param link: The link to which the SwitchPort is connected.
"""
if self.connected_link:
if self._connected_link:
_LOGGER.error(f"Cannot connect link to SwitchPort {self.mac_address} as it already has a connection")
return
if self.connected_link == link:
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to SwitchPort {self.mac_address} as it is already connected")
return
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
self._connected_link = link
_LOGGER.debug(f"SwitchPort {self} connected to Link {link}")
self.enable()
def disconnect_link(self):
"""Disconnect the SwitchPort from the connected Link."""
if self.connected_link.endpoint_a == self:
self.connected_link.endpoint_a = None
if self.connected_link.endpoint_b == self:
self.connected_link.endpoint_b = None
self.connected_link = None
if self._connected_link.endpoint_a == self:
self._connected_link.endpoint_a = None
if self._connected_link.endpoint_b == self:
self._connected_link.endpoint_b = None
self._connected_link = None
def send_frame(self, frame: Frame) -> bool:
"""
@@ -388,7 +394,7 @@ class SwitchPort(SimComponent):
"""
if self.enabled:
self.pcap.capture(frame)
self.connected_link.transmit_frame(sender_nic=self, frame=frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the SwitchPort is not enabled
return False
@@ -404,7 +410,7 @@ class SwitchPort(SimComponent):
if self.enabled:
frame.decrement_ttl()
self.pcap.capture(frame)
connected_node: Node = self.connected_node
connected_node: Node = self._connected_node
connected_node.forward_frame(frame=frame, incoming_port=self)
return True
return False
@@ -940,9 +946,36 @@ class Node(SimComponent):
super().__init__(**kwargs)
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
self._install_system_software()
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
am = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
self._service_request_manager = RequestManager()
am.add_request("service", RequestType(func=self._service_request_manager))
self._nic_request_manager = RequestManager()
am.add_request("nic", RequestType(func=self._nic_request_manager))
am.add_request("file_system", RequestType(func=self.file_system._request_manager))
# currently we don't have any applications nor processes, so these will be empty
self._process_request_manager = RequestManager()
am.add_request("process", RequestType(func=self._process_request_manager))
self._application_request_manager = RequestManager()
am.add_request("application", RequestType(func=self._application_request_manager))
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement OS scan
am.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
am.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
am.add_request("reset", RequestType(func=lambda request, context: ...)) # TODO implement node reset
am.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
am.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
return am
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
@@ -1014,7 +1047,7 @@ class Node(SimComponent):
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
for nic in self.nics.values():
if nic.connected_link:
if nic._connected_link:
nic.enable()
def power_off(self):
@@ -1035,11 +1068,12 @@ class Node(SimComponent):
if nic.uuid not in self.nics:
self.nics[nic.uuid] = nic
self.ethernet_port[len(self.nics)] = nic
nic.connected_node = self
nic._connected_node = self
nic.parent = self
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:
nic.enable()
self._nic_request_manager.add_request(nic.uuid, RequestType(func=nic._request_manager))
else:
msg = f"Cannot connect NIC {nic} as it is already connected"
self.sys_log.logger.error(msg)
@@ -1064,6 +1098,7 @@ class Node(SimComponent):
nic.parent = None
nic.disable()
self.sys_log.info(f"Disconnected NIC {nic}")
self._nic_request_manager.remove_request(nic.uuid)
else:
msg = f"Cannot disconnect NIC {nic} as it is not connected"
self.sys_log.logger.error(msg)
@@ -1160,7 +1195,8 @@ class Node(SimComponent):
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.debug(f"Added service {service.uuid} to node {self.uuid}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
@@ -1175,7 +1211,8 @@ class Node(SimComponent):
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
_LOGGER.debug(f"Removed service {service.uuid} from node {self.uuid}")
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
self._service_request_manager.remove_request(service.uuid)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
@@ -1198,7 +1235,7 @@ class Switch(Node):
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port._connected_node = self
port.parent = self
port.port_num = port_num
@@ -1271,7 +1308,7 @@ class Switch(Node):
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
@@ -43,7 +43,7 @@ class ACLRule(SimComponent):
def __str__(self) -> str:
rule_strings = []
for key, value in self.model_dump(exclude={"uuid", "action_manager"}).items():
for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items():
if value is None:
value = "ANY"
if isinstance(value, Enum):
@@ -87,6 +87,36 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
# POSITIONAL ARGUMENTS:
# 0: action (str name of an ACLAction)
# 1: protocol (str name of an IPProtocol)
# 2: source ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 3: source port (str name of a Port (e.g. "HTTP")) # should we be using value, such as 80 or 443?
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 5: destination port (str name of a Port (e.g. "HTTP"))
# 6: position (int)
am.add_request(
"add_rule",
RequestType(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
IPProtocol[request[1]],
IPv4Address[request[2]],
Port[request[3]],
IPv4Address[request[4]],
Port[request[5]],
int(request[6]),
)
),
)
am.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
return am
def describe_state(self) -> Dict:
"""
Describes the current state of the AccessControlList.
@@ -596,6 +626,11 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("acl", RequestType(func=self.acl._request_manager))
return am
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:
"""
Retrieve the port number for a given NIC.

View File

@@ -30,7 +30,7 @@ class Switch(Node):
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port.connected_node = self
port._connected_node = self
port.parent = self
port.port_num = port_num
@@ -113,7 +113,7 @@ class Switch(Node):
_LOGGER.error(msg)
raise NetworkError(msg)
if port.connected_link != link:
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)

View File

@@ -1,6 +1,6 @@
from typing import Dict
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.controller import DomainController
from primaite.simulator.network.container import Network
@@ -21,22 +21,12 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# pass through network actions to the network objects
am.add_action(
"network",
Action(
func=lambda request, context: self.network.apply_action(request, context), validator=AllowAllValidator()
),
)
# pass through domain actions to the domain object
am.add_action(
"domain",
Action(
func=lambda request, context: self.domain.apply_action(request, context), validator=AllowAllValidator()
),
)
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
# pass through network requests to the network objects
am.add_request("network", RequestType(func=self.network._request_manager))
# pass through domain requests to the domain object
am.add_request("domain", RequestType(func=self.domain._request_manager))
return am
def describe_state(self) -> Dict:

View File

@@ -2,7 +2,7 @@ from enum import Enum
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.software import IOSoftware
_LOGGER = getLogger(__name__)
@@ -39,15 +39,15 @@ class Service(IOSoftware):
_restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("stop", Action(func=lambda request, context: self.stop()))
am.add_action("start", Action(func=lambda request, context: self.start()))
am.add_action("pause", Action(func=lambda request, context: self.pause()))
am.add_action("resume", Action(func=lambda request, context: self.resume()))
am.add_action("restart", Action(func=lambda request, context: self.restart()))
am.add_action("disable", Action(func=lambda request, context: self.disable()))
am.add_action("enable", Action(func=lambda request, context: self.enable()))
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("stop", RequestType(func=lambda request, context: self.stop()))
am.add_request("start", RequestType(func=lambda request, context: self.start()))
am.add_request("pause", RequestType(func=lambda request, context: self.pause()))
am.add_request("resume", RequestType(func=lambda request, context: self.resume()))
am.add_request("restart", RequestType(func=lambda request, context: self.restart()))
am.add_request("disable", RequestType(func=lambda request, context: self.disable()))
am.add_request("enable", RequestType(func=lambda request, context: self.enable()))
return am
def describe_state(self) -> Dict:

View File

@@ -3,7 +3,7 @@ from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
@@ -87,15 +87,15 @@ class Software(SimComponent):
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action(
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request(
"compromise",
Action(
RequestType(
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
),
)
am.add_action("scan", Action(func=lambda request, context: self.scan()))
am.add_request("scan", RequestType(func=lambda request, context: self.scan()))
return am
def _get_session_details(self, session_id: str) -> Session:
@@ -214,7 +214,7 @@ class IOSoftware(Software):
"max_sessions": self.max_sessions,
"tcp": self.tcp,
"udp": self.udp,
"ports": [port.name for port in self.ports], # TODO: not sure if this should be port.name or port.value
"port": self.port.value,
}
)
return state

View File

@@ -0,0 +1,55 @@
import pytest
from primaite.simulator.core import RequestType
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.services.database.database_service import DatabaseService
def test_passing_actions_down(monkeypatch) -> None:
"""Check that an action is passed down correctly to the child component."""
sim = Simulation()
pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0")
pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0")
srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0")
s1 = Switch(hostname="switch1")
for n in [pc1, pc2, srv, s1]:
sim.network.add_node(n)
database_service = DatabaseService(file_system=srv.file_system)
srv.install_service(database_service)
downloads_folder = pc1.file_system.create_folder("downloads")
pc1.file_system.create_file("bermuda_triangle.png", folder_name="downloads")
sim.network.connect(pc1.ethernet_port[1], s1.switch_ports[1])
sim.network.connect(pc2.ethernet_port[1], s1.switch_ports[2])
sim.network.connect(s1.switch_ports[3], srv.ethernet_port[1])
# call this method to make sure no errors occur.
sim._request_manager.get_request_types_recursively()
# patch the action to do something which we can check the result of.
action_invoked = False
def succeed():
nonlocal action_invoked
action_invoked = True
monkeypatch.setitem(
downloads_folder._request_manager.request_types, "repair", RequestType(func=lambda request, context: succeed())
)
assert not action_invoked
# call the patched method
sim.apply_request(
["network", "node", pc1.uuid, "file_system", "folder", pc1.file_system.get_folder("downloads").uuid, "repair"]
)
assert action_invoked

View File

@@ -3,10 +3,11 @@ from typing import Dict, List, Literal
import pytest
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import AllowAllValidator, RequestManager, RequestType, SimComponent
from primaite.simulator.domain.controller import AccountGroup, GroupMembershipValidator
@pytest.mark.skip(reason="Action validation is not currently a required feature.")
def test_group_action_validation() -> None:
"""
Check that actions are denied when an unauthorised request is made.
@@ -28,11 +29,11 @@ def test_group_action_validation() -> None:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self._request_manager = RequestManager()
self.action_manager.add_action(
self._request_manager.add_request(
"create_folder",
Action(
RequestType(
func=lambda request, context: self.create_folder(request[0]),
validator=GroupMembershipValidator([AccountGroup.LOCAL_ADMIN, AccountGroup.DOMAIN_ADMIN]),
),
@@ -51,17 +52,18 @@ def test_group_action_validation() -> None:
# check that the folder is created when a local admin tried to do it
permitted_context = {"request_source": {"agent": "BLUE", "account": "User1", "groups": ["LOCAL_ADMIN"]}}
my_node = Node(uuid="0000-0000-1234", name="pc")
my_node.apply_action(["create_folder", "memes"], context=permitted_context)
my_node.apply_request(["create_folder", "memes"], context=permitted_context)
assert len(my_node.folders) == 1
assert my_node.folders[0].name == "memes"
# check that the number of folders is still 1 even after attempting to create a second one without permissions
invalid_context = {"request_source": {"agent": "BLUE", "account": "User1", "groups": ["LOCAL_USER", "DOMAIN_USER"]}}
my_node.apply_action(["create_folder", "memes2"], context=invalid_context)
my_node.apply_request(["create_folder", "memes2"], context=invalid_context)
assert len(my_node.folders) == 1
assert my_node.folders[0].name == "memes"
@pytest.mark.skip(reason="Action validation is not currently a required feature.")
def test_hierarchical_action_with_validation() -> None:
"""
Check that validation works with sub-objects.
@@ -77,32 +79,32 @@ def test_hierarchical_action_with_validation() -> None:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.request_manager = RequestManager()
self.action_manager.add_action(
self.request_manager.add_request(
"turn_on",
Action(
RequestType(
func=lambda request, context: self.turn_on(),
validator=AllowAllValidator(),
),
)
self.action_manager.add_action(
self.request_manager.add_request(
"turn_off",
Action(
RequestType(
func=lambda request, context: self.turn_off(),
validator=AllowAllValidator(),
),
)
self.action_manager.add_action(
self.request_manager.add_request(
"disable",
Action(
RequestType(
func=lambda request, context: self.disable(),
validator=GroupMembershipValidator([AccountGroup.LOCAL_ADMIN, AccountGroup.DOMAIN_ADMIN]),
),
)
self.action_manager.add_action(
self.request_manager.add_request(
"enable",
Action(
RequestType(
func=lambda request, context: self.enable(),
validator=GroupMembershipValidator([AccountGroup.LOCAL_ADMIN, AccountGroup.DOMAIN_ADMIN]),
),
@@ -133,11 +135,11 @@ def test_hierarchical_action_with_validation() -> None:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.action_manager = ActionManager()
self.request_manager = RequestManager()
self.action_manager.add_action(
self.request_manager.add_request(
"apps",
Action(
RequestType(
func=lambda request, context: self.send_action_to_app(request.pop(0), request, context),
validator=AllowAllValidator(),
),
@@ -153,7 +155,7 @@ def test_hierarchical_action_with_validation() -> None:
def send_action_to_app(self, app_name: str, options: List[str], context: Dict):
for app in self.apps:
if app_name == app.name:
app.apply_action(options, context)
app.apply_request(options, context)
break
else:
msg = f"Node has no app with name {app_name}"
@@ -176,15 +178,15 @@ def test_hierarchical_action_with_validation() -> None:
}
# check that a non-admin can't disable this app
my_node.apply_action(["apps", "Chrome", "disable"], non_admin_context)
my_node.apply_request(["apps", "Chrome", "disable"], non_admin_context)
assert my_node.apps[0].name == "Chrome" # if failure occurs on this line, the test itself is broken
assert my_node.apps[0].state == "off"
# check that a non-admin can turn this app on
my_node.apply_action(["apps", "Firefox", "turn_on"], non_admin_context)
my_node.apply_request(["apps", "Firefox", "turn_on"], non_admin_context)
assert my_node.apps[1].name == "Firefox" # if failure occurs on this line, the test itself is broken
assert my_node.apps[1].state == "on"
# check that an admin can disable this app
my_node.apply_action(["apps", "Chrome", "disable"], admin_context)
my_node.apply_request(["apps", "Chrome", "disable"], admin_context)
assert my_node.apps[0].state == "disabled"

View File

@@ -70,8 +70,8 @@ def test_connecting_node_to_itself():
net.connect(node.nics[nic1.uuid], node.nics[nic2.uuid], bandwidth=30)
assert node in net
assert nic1.connected_link is None
assert nic2.connected_link is None
assert nic1._connected_link is None
assert nic2._connected_link is None
assert len(net.links) == 0

View File

@@ -13,6 +13,6 @@ def test_account_deserialise():
"""Test that an account can be deserialised. The test fails if pydantic throws an error."""
acct_json = (
'{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,'
'"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"action_manager":null}'
'"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"request_manager":null}'
)
acct = Account.model_validate_json(acct_json)