Merge remote-tracking branch 'origin/dev' into feature/2350-confirm-action-observation-space-conforms-to-CAOS-0.7
This commit is contained in:
@@ -492,9 +492,9 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
"add_rule",
|
||||
permission_str,
|
||||
protocol,
|
||||
src_ip,
|
||||
str(src_ip),
|
||||
src_port,
|
||||
dst_ip,
|
||||
str(dst_ip),
|
||||
dst_port,
|
||||
position,
|
||||
]
|
||||
|
||||
@@ -170,9 +170,13 @@ class PrimaiteGame:
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
response = self.simulation.apply_request(request)
|
||||
agent_actions[agent.agent_name] = {
|
||||
"action": action_choice,
|
||||
"parameters": options,
|
||||
"response": response.model_dump(),
|
||||
}
|
||||
return agent_actions
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
|
||||
0
src/primaite/interface/__init__.py
Normal file
0
src/primaite/interface/__init__.py
Normal file
44
src/primaite/interface/request.py
Normal file
44
src/primaite/interface/request.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
|
||||
class RequestResponse(BaseModel):
|
||||
"""Schema for generic request responses."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
"""Cannot have extra fields in the response. Anything custom goes into the data field."""
|
||||
|
||||
status: Literal["pending", "success", "failure", "unreachable"] = "pending"
|
||||
"""
|
||||
What is the current status of the request:
|
||||
- pending - the request has not been received yet, or it has been received but it's still being processed.
|
||||
- success - the request has been received and executed successfully.
|
||||
- failure - the request has been received and attempted, but execution failed.
|
||||
- unreachable - the request could not reach it's intended target, either because it doesn't exist or the target
|
||||
is off.
|
||||
"""
|
||||
|
||||
data: Dict = {}
|
||||
"""Catch-all place to provide any additional data that was generated as a response to the request."""
|
||||
# TODO: currently, status and data have default values, because I don't want to interrupt existing functionality too
|
||||
# much. However, in the future we might consider making them mandatory.
|
||||
|
||||
@classmethod
|
||||
@validate_call
|
||||
def from_bool(cls, status_bool: StrictBool) -> RequestResponse:
|
||||
"""
|
||||
Construct a basic request response from a boolean.
|
||||
|
||||
True maps to a success status. False maps to a failure status.
|
||||
|
||||
:param status_bool: Whether to create a successful response
|
||||
:type status_bool: bool
|
||||
"""
|
||||
if status_bool is True:
|
||||
return cls(status="success", data={})
|
||||
elif status_bool is False:
|
||||
return cls(status="failure", data={})
|
||||
@@ -426,13 +426,13 @@
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info[0]\n",
|
||||
" red_action = red_info['action']\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
|
||||
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
" return red_str\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -492,7 +492,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent."
|
||||
"Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) increased from 0 to 1, but only right after the red attack, so we probably cannot see it now."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -510,9 +510,9 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
|
||||
"print(f\"Blue reward:{reward}\" )"
|
||||
]
|
||||
},
|
||||
@@ -535,7 +535,7 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
|
||||
"print(f\"Blue reward:{reward:.2f}\" )"
|
||||
@@ -557,17 +557,17 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(50) # Block client 1\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(51) # Block client 2\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"for step in range(30):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -606,20 +606,35 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's unblock client 2\n",
|
||||
" env.step(58) # remove ACL rule 6\n",
|
||||
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" env.step(57) # remove ACL rule 5\n",
|
||||
"else:\n",
|
||||
" print(\"something went wrong, neither client has NMNEs\")"
|
||||
"env.step(58) # Remove the ACL rule that blocks client 1\n",
|
||||
"env.step(57) # Remove the ACL rule that blocks client 2\n",
|
||||
"\n",
|
||||
"tries = 0\n",
|
||||
"while True:\n",
|
||||
" tries += 1\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0)\n",
|
||||
"\n",
|
||||
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
|
||||
" break\n",
|
||||
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 2 has NMNEs, so let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
|
||||
" break\n",
|
||||
" if tries>100:\n",
|
||||
" print(\"Error: NMNE never increased\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"env.step(13) # Patch the database\n",
|
||||
"..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks."
|
||||
"Now, the reward will eventually increase to 0.9, even after red agent attempts subsequent attacks."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -628,9 +643,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for step in range(30):\n",
|
||||
"\n",
|
||||
"for step in range(40):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -648,6 +664,13 @@
|
||||
"source": [
|
||||
"env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -93,7 +93,7 @@ class PrimaiteIO:
|
||||
{
|
||||
"episode": episode,
|
||||
"timestep": timestep,
|
||||
"agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()},
|
||||
"agent_actions": agent_actions,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
# flake8: noqa
|
||||
"""Core of the PrimAITE Simulator."""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
|
||||
class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
@@ -21,15 +24,15 @@ class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, request: List[str], context: Dict) -> bool:
|
||||
"""Use the request and context paramters to decide whether the request should be permitted."""
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Use the request and context parameters to decide whether the request should be permitted."""
|
||||
pass
|
||||
|
||||
|
||||
class AllowAllValidator(RequestPermissionValidator):
|
||||
"""Always allows the request."""
|
||||
|
||||
def __call__(self, request: List[str], context: Dict) -> bool:
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Always allow the request."""
|
||||
return True
|
||||
|
||||
@@ -42,7 +45,7 @@ class RequestType(BaseModel):
|
||||
the request can be performed or not.
|
||||
"""
|
||||
|
||||
func: Callable[[List[str], Dict], None]
|
||||
func: Callable[[RequestFormat, Dict], RequestResponse]
|
||||
"""
|
||||
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
|
||||
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
|
||||
@@ -71,7 +74,7 @@ class RequestManager(BaseModel):
|
||||
request_types: Dict[str, RequestType] = {}
|
||||
"""maps request name to an RequestType object."""
|
||||
|
||||
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""
|
||||
Process an request request.
|
||||
|
||||
@@ -84,23 +87,23 @@ class RequestManager(BaseModel):
|
||||
:raises RuntimeError: If the request parameter does not have a valid request name as the first item.
|
||||
"""
|
||||
request_key = request[0]
|
||||
request_options = request[1:]
|
||||
|
||||
if request_key not in self.request_types:
|
||||
msg = (
|
||||
f"Request {request} could not be processed because {request_key} is not a valid request name",
|
||||
"within this RequestManager",
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
_LOGGER.debug(msg)
|
||||
return RequestResponse(status="unreachable", data={"reason": msg})
|
||||
|
||||
request_type = self.request_types[request_key]
|
||||
request_options = request[1:]
|
||||
|
||||
if not request_type.validator(request_options, context):
|
||||
_LOGGER.debug(f"Request {request} was denied due to insufficient permissions")
|
||||
return
|
||||
return RequestResponse(status="failure", data={"reason": "request validation failed"})
|
||||
|
||||
request_type.func(request_options, context)
|
||||
return request_type.func(request_options, context)
|
||||
|
||||
def add_request(self, name: str, request_type: RequestType) -> None:
|
||||
"""
|
||||
@@ -202,7 +205,8 @@ class SimComponent(BaseModel):
|
||||
}
|
||||
return state
|
||||
|
||||
def apply_request(self, request: List[str], context: Dict = {}) -> None:
|
||||
@validate_call
|
||||
def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse:
|
||||
"""
|
||||
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
|
||||
|
||||
@@ -222,7 +226,7 @@ class SimComponent(BaseModel):
|
||||
"""
|
||||
if self._request_manager is None:
|
||||
return
|
||||
self._request_manager(request, context)
|
||||
return self._request_manager(request, context)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -80,6 +80,11 @@ class DomainController(SimComponent):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
# Action 'account' matches requests like:
|
||||
# ['account', '<account-uuid>', *account_action]
|
||||
@@ -87,6 +92,7 @@ class DomainController(SimComponent):
|
||||
"account",
|
||||
RequestType(
|
||||
func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context),
|
||||
# TODO: not sure what should get returned here, revisit
|
||||
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -114,16 +114,17 @@ class File(FileSystemItemABC):
|
||||
state["num_access"] = self.num_access
|
||||
return state
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Updates the visible statuses of the file."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
|
||||
self.visible_health_status = self.health_status
|
||||
return True
|
||||
|
||||
def reveal_to_red(self) -> None:
|
||||
"""Reveals the folder/file to the red agent."""
|
||||
@@ -132,7 +133,7 @@ class File(FileSystemItemABC):
|
||||
return
|
||||
self.revealed_to_red = True
|
||||
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Check if the file has been changed.
|
||||
|
||||
@@ -142,7 +143,7 @@ class File(FileSystemItemABC):
|
||||
"""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to check hash of deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
current_hash = None
|
||||
|
||||
# if file is real, read the file contents
|
||||
@@ -164,12 +165,13 @@ class File(FileSystemItemABC):
|
||||
# if the previous hash and current hash do not match, mark file as corrupted
|
||||
if self.previous_hash is not current_hash:
|
||||
self.corrupt()
|
||||
return True
|
||||
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to repair deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# set file status to good if corrupt
|
||||
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
@@ -178,12 +180,13 @@ class File(FileSystemItemABC):
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to corrupt deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# set file status to good if corrupt
|
||||
if self.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
@@ -192,12 +195,13 @@ class File(FileSystemItemABC):
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""Determines if the file needs to be repaired or unmarked as deleted."""
|
||||
if self.deleted:
|
||||
self.deleted = False
|
||||
return
|
||||
return True
|
||||
|
||||
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.health_status = FileSystemItemHealthStatus.GOOD
|
||||
@@ -205,13 +209,15 @@ class File(FileSystemItemABC):
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> bool:
|
||||
"""Marks the file as deleted."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
self.deleted = True
|
||||
self.sys_log.info(f"File deleted {self.folder_name}/{self.name}")
|
||||
return True
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
@@ -39,18 +40,27 @@ class FileSystem(SimComponent):
|
||||
self.create_folder("root")
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
self._delete_manager = RequestManager()
|
||||
self._delete_manager.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: self.delete_file(folder_name=request[0], file_name=request[1])
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.delete_file(folder_name=request[0], file_name=request[1])
|
||||
)
|
||||
),
|
||||
)
|
||||
self._delete_manager.add_request(
|
||||
name="folder",
|
||||
request_type=RequestType(func=lambda request, context: self.delete_folder(folder_name=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
name="delete",
|
||||
@@ -61,12 +71,16 @@ class FileSystem(SimComponent):
|
||||
self._restore_manager.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: self.restore_file(folder_name=request[0], file_name=request[1])
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.restore_file(folder_name=request[0], file_name=request[1])
|
||||
)
|
||||
),
|
||||
)
|
||||
self._restore_manager.add_request(
|
||||
name="folder",
|
||||
request_type=RequestType(func=lambda request, context: self.restore_folder(folder_name=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.restore_folder(folder_name=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
name="restore",
|
||||
@@ -142,7 +156,7 @@ class FileSystem(SimComponent):
|
||||
)
|
||||
return folder
|
||||
|
||||
def delete_folder(self, folder_name: str):
|
||||
def delete_folder(self, folder_name: str) -> bool:
|
||||
"""
|
||||
Deletes a folder, removes it from the folders list and removes any child folders and files.
|
||||
|
||||
@@ -150,24 +164,26 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
return False
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
else:
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
|
||||
return False
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str):
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
return True
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str) -> None:
|
||||
"""
|
||||
Deletes a folder via its uuid.
|
||||
|
||||
@@ -303,7 +319,7 @@ class FileSystem(SimComponent):
|
||||
|
||||
return file
|
||||
|
||||
def delete_file(self, folder_name: str, file_name: str):
|
||||
def delete_file(self, folder_name: str, file_name: str) -> bool:
|
||||
"""
|
||||
Delete a file by its name from a specific folder.
|
||||
|
||||
@@ -317,8 +333,10 @@ class FileSystem(SimComponent):
|
||||
# increment file creation
|
||||
self.num_file_deletions += 1
|
||||
folder.remove_file(file)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_file_by_id(self, folder_uuid: str, file_uuid: str):
|
||||
def delete_file_by_id(self, folder_uuid: str, file_uuid: str) -> None:
|
||||
"""
|
||||
Deletes a file via its uuid.
|
||||
|
||||
@@ -335,7 +353,7 @@ class FileSystem(SimComponent):
|
||||
else:
|
||||
self.sys_log.error(f"Unable to delete file that does not exist. (id: {file_uuid})")
|
||||
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str) -> None:
|
||||
"""
|
||||
Move a file from one folder to another.
|
||||
|
||||
@@ -423,7 +441,7 @@ class FileSystem(SimComponent):
|
||||
# Agent actions
|
||||
###############################################################
|
||||
|
||||
def scan(self, instant_scan: bool = False):
|
||||
def scan(self, instant_scan: bool = False) -> None:
|
||||
"""
|
||||
Scan all the folders (and child files) in the file system.
|
||||
|
||||
@@ -432,7 +450,7 @@ class FileSystem(SimComponent):
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].scan(instant_scan=instant_scan)
|
||||
|
||||
def reveal_to_red(self, instant_scan: bool = False):
|
||||
def reveal_to_red(self, instant_scan: bool = False) -> None:
|
||||
"""
|
||||
Reveals all the folders (and child files) in the file system to the red agent.
|
||||
|
||||
@@ -441,7 +459,7 @@ class FileSystem(SimComponent):
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].reveal_to_red(instant_scan=instant_scan)
|
||||
|
||||
def restore_folder(self, folder_name: str):
|
||||
def restore_folder(self, folder_name: str) -> bool:
|
||||
"""
|
||||
Restore a folder.
|
||||
|
||||
@@ -454,13 +472,14 @@ class FileSystem(SimComponent):
|
||||
|
||||
if folder is None:
|
||||
self.sys_log.error(f"Unable to restore folder {folder_name}. Folder is not in deleted folder list.")
|
||||
return
|
||||
return False
|
||||
|
||||
self.deleted_folders.pop(folder.uuid, None)
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
return True
|
||||
|
||||
def restore_file(self, folder_name: str, file_name: str):
|
||||
def restore_file(self, folder_name: str, file_name: str) -> bool:
|
||||
"""
|
||||
Restore a file.
|
||||
|
||||
@@ -473,12 +492,15 @@ class FileSystem(SimComponent):
|
||||
:type: file_name: str
|
||||
"""
|
||||
folder = self.get_folder(folder_name=folder_name)
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.")
|
||||
return False
|
||||
|
||||
if folder:
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
|
||||
if file is None:
|
||||
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
|
||||
return
|
||||
if not file:
|
||||
msg = f"Unable to restore file {file_name}. File was not found."
|
||||
self.sys_log.error(msg)
|
||||
return False
|
||||
|
||||
folder.restore_file(file_name=file_name)
|
||||
return folder.restore_file(file_name=file_name)
|
||||
|
||||
@@ -6,6 +6,7 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
@@ -100,14 +101,33 @@ class FileSystemItemABC(SimComponent):
|
||||
return state
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash()))
|
||||
rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair()))
|
||||
rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore()))
|
||||
rm.add_request(
|
||||
name="scan", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
|
||||
)
|
||||
rm.add_request(
|
||||
name="checkhash",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.check_hash())),
|
||||
)
|
||||
rm.add_request(
|
||||
name="repair",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.repair())),
|
||||
)
|
||||
rm.add_request(
|
||||
name="restore",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.restore())),
|
||||
)
|
||||
|
||||
rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt()))
|
||||
rm.add_request(
|
||||
name="corrupt",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.corrupt())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -124,9 +144,9 @@ class FileSystemItemABC(SimComponent):
|
||||
return convert_size(self.size)
|
||||
|
||||
@abstractmethod
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Scan the folder/file - updates the visible_health_status."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def reveal_to_red(self) -> None:
|
||||
@@ -134,7 +154,7 @@ class FileSystemItemABC(SimComponent):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Checks the has of the file to detect any changes.
|
||||
|
||||
@@ -142,30 +162,30 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""
|
||||
Repair the FileSystemItem.
|
||||
|
||||
True if successfully repaired. False otherwise.
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""
|
||||
Corrupt the FileSystemItem.
|
||||
|
||||
True if successfully corrupted. False otherwise.
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""Restore the file/folder to the state before it got ruined."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def delete(self) -> None:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
@@ -50,10 +51,17 @@ class Folder(FileSystemItemABC):
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="delete",
|
||||
request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.remove_file_by_name(file_name=request[0]))
|
||||
),
|
||||
)
|
||||
self._file_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
@@ -250,6 +258,21 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file_by_id(file_uuid=file_uuid)
|
||||
self.remove_file(file=file)
|
||||
|
||||
def remove_file_by_name(self, file_name: str) -> bool:
|
||||
"""
|
||||
Remove a file using its name.
|
||||
|
||||
:param file_name: filename
|
||||
:type file_name: str
|
||||
:return: Whether it was successfully removed.
|
||||
:rtype: bool
|
||||
"""
|
||||
for f in self.files.values():
|
||||
if f.name == file_name:
|
||||
self.remove_file(f)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_all_files(self):
|
||||
"""Removes all the files in the folder."""
|
||||
for file_id in self.files:
|
||||
@@ -259,7 +282,7 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
self.files = {}
|
||||
|
||||
def restore_file(self, file_name: str):
|
||||
def restore_file(self, file_name: str) -> bool:
|
||||
"""
|
||||
Restores a file.
|
||||
|
||||
@@ -269,13 +292,14 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file(file_name=file_name, include_deleted=True)
|
||||
if not file:
|
||||
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
|
||||
return
|
||||
return False
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file.uuid)
|
||||
return True
|
||||
|
||||
def quarantine(self):
|
||||
"""Quarantines the File System Folder."""
|
||||
@@ -289,7 +313,7 @@ class Folder(FileSystemItemABC):
|
||||
"""Returns true if the folder is being quarantined."""
|
||||
pass
|
||||
|
||||
def scan(self, instant_scan: bool = False) -> None:
|
||||
def scan(self, instant_scan: bool = False) -> bool:
|
||||
"""
|
||||
Update Folder visible status.
|
||||
|
||||
@@ -297,7 +321,7 @@ class Folder(FileSystemItemABC):
|
||||
"""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to scan deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
if instant_scan:
|
||||
for file_id in self.files:
|
||||
@@ -305,7 +329,7 @@ class Folder(FileSystemItemABC):
|
||||
file.scan()
|
||||
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
return
|
||||
return True
|
||||
|
||||
if self.scan_countdown <= 0:
|
||||
# scan one file per timestep
|
||||
@@ -314,6 +338,7 @@ class Folder(FileSystemItemABC):
|
||||
else:
|
||||
# scan already in progress
|
||||
self.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def reveal_to_red(self, instant_scan: bool = False):
|
||||
"""
|
||||
@@ -340,7 +365,7 @@ class Folder(FileSystemItemABC):
|
||||
# scan already in progress
|
||||
self.sys_log.info(f"Red Agent Scan is already in progress {self.name} (id: {self.uuid})")
|
||||
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Runs a :func:`check_hash` on all files in the folder.
|
||||
|
||||
@@ -353,7 +378,7 @@ class Folder(FileSystemItemABC):
|
||||
"""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to check hash of deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# iterate through the files and run a check hash
|
||||
no_corrupted_files = True
|
||||
@@ -369,12 +394,13 @@ class Folder(FileSystemItemABC):
|
||||
self.corrupt()
|
||||
|
||||
self.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to repair deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# iterate through the files in the folder
|
||||
for file_id in self.files:
|
||||
@@ -388,8 +414,9 @@ class Folder(FileSystemItemABC):
|
||||
self.health_status = FileSystemItemHealthStatus.GOOD
|
||||
|
||||
self.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""
|
||||
If a Folder is corrupted, run a repair on the folder and its child files.
|
||||
|
||||
@@ -405,12 +432,13 @@ class Folder(FileSystemItemABC):
|
||||
else:
|
||||
# scan already in progress
|
||||
self.sys_log.info(f"Folder restoration already in progress {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to corrupt deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
# iterate through the files in the folder
|
||||
for file_id in self.files:
|
||||
@@ -421,11 +449,13 @@ class Folder(FileSystemItemABC):
|
||||
self.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
self.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})")
|
||||
return True
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> bool:
|
||||
"""Marks the file as deleted. Prevents agent actions from occuring."""
|
||||
if self.deleted:
|
||||
self.sys_log.error(f"Unable to delete an already deleted folder {self.name}")
|
||||
return
|
||||
return False
|
||||
|
||||
self.deleted = True
|
||||
return True
|
||||
|
||||
@@ -61,6 +61,11 @@ class Network(SimComponent):
|
||||
software.run()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
self._node_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
|
||||
@@ -12,8 +12,9 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -113,10 +114,15 @@ class NetworkInterface(SimComponent, ABC):
|
||||
self.enable()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -140,14 +146,16 @@ class NetworkInterface(SimComponent, ABC):
|
||||
return state
|
||||
|
||||
@abstractmethod
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return False
|
||||
|
||||
def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None:
|
||||
"""
|
||||
@@ -257,10 +265,9 @@ class NetworkInterface(SimComponent, ABC):
|
||||
"""
|
||||
Apply a timestep evolution to this component.
|
||||
|
||||
This just clears the nmne count back to 0.tests/integration_tests/network/test_capture_nmne.py
|
||||
This just clears the nmne count back to 0.
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
self.nmne.clear()
|
||||
|
||||
|
||||
class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
@@ -284,24 +291,24 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
_connected_link: Optional[Link] = None
|
||||
"The network link to which the network interface is connected."
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Attempt to enable the network interface."""
|
||||
if self.enabled:
|
||||
return
|
||||
return True
|
||||
|
||||
if not self._connected_node:
|
||||
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
|
||||
return
|
||||
return False
|
||||
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.info(
|
||||
f"Interface {self} cannot be enabled as the connected Node is not powered on"
|
||||
)
|
||||
return
|
||||
return False
|
||||
|
||||
if not self._connected_link:
|
||||
self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.")
|
||||
return
|
||||
return False
|
||||
|
||||
self.enabled = True
|
||||
self._connected_node.sys_log.info(f"Network Interface {self} enabled")
|
||||
@@ -310,11 +317,12 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
)
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_up()
|
||||
return True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the network interface."""
|
||||
if not self.enabled:
|
||||
return
|
||||
return True
|
||||
self.enabled = False
|
||||
if self._connected_node:
|
||||
self._connected_node.sys_log.info(f"Network Interface {self} disabled")
|
||||
@@ -322,6 +330,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
_LOGGER.debug(f"Interface {self} disabled")
|
||||
if self._connected_link:
|
||||
self._connected_link.endpoint_down()
|
||||
return True
|
||||
|
||||
def connect_link(self, link: Link):
|
||||
"""
|
||||
@@ -496,7 +505,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
|
||||
return state
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""
|
||||
Enables this wired network interface and attempts to send a "hello" message to the default gateway.
|
||||
|
||||
@@ -512,8 +521,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
try:
|
||||
pass
|
||||
self._connected_node.default_gateway_hello()
|
||||
return True
|
||||
except AttributeError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
@@ -765,35 +776,74 @@ class Node(SimComponent):
|
||||
self.sys_log.current_episode = episode
|
||||
self.sys_log.setup_logger()
|
||||
|
||||
class _NodeIsOnValidator(RequestPermissionValidator):
|
||||
"""
|
||||
When requests come in, this validator will only let them through if the node is on.
|
||||
|
||||
This is useful because no actions should be being resolved if the node is off.
|
||||
"""
|
||||
|
||||
node: Node
|
||||
"""Save a reference to the node instance."""
|
||||
|
||||
def __call__(self, request: RequestFormat, context: Dict) -> bool:
|
||||
"""Return whether the node is on or off."""
|
||||
return self.node.operating_state == NodeOperatingState.ON
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
|
||||
# need a better name and better documentation.
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
_node_is_on = Node._NodeIsOnValidator(node=self)
|
||||
|
||||
rm = super()._init_request_manager()
|
||||
# since there are potentially many services, create an request manager that can map service name
|
||||
self._service_request_manager = RequestManager()
|
||||
rm.add_request("service", RequestType(func=self._service_request_manager))
|
||||
rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on))
|
||||
self._nic_request_manager = RequestManager()
|
||||
rm.add_request("network_interface", RequestType(func=self._nic_request_manager))
|
||||
rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on))
|
||||
|
||||
rm.add_request("file_system", RequestType(func=self.file_system._request_manager))
|
||||
rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on))
|
||||
|
||||
# currently we don't have any applications nor processes, so these will be empty
|
||||
self._process_request_manager = RequestManager()
|
||||
rm.add_request("process", RequestType(func=self._process_request_manager))
|
||||
rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on))
|
||||
self._application_request_manager = RequestManager()
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager))
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on))
|
||||
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.reveal_to_red()))
|
||||
rm.add_request(
|
||||
"scan",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
|
||||
rm.add_request("reset", RequestType(func=lambda request, context: self.reset())) # TODO implement node reset
|
||||
rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
|
||||
rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
|
||||
rm.add_request(
|
||||
"shutdown",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
|
||||
),
|
||||
)
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
|
||||
rm.add_request(
|
||||
"reset",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),
|
||||
) # TODO implement node reset
|
||||
rm.add_request(
|
||||
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
|
||||
) # TODO implement logon request
|
||||
rm.add_request(
|
||||
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on)
|
||||
) # TODO implement logoff request
|
||||
|
||||
self._os_request_manager = RequestManager()
|
||||
self._os_request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager))
|
||||
self._os_request_manager.add_request(
|
||||
"scan",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on),
|
||||
)
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -973,7 +1023,7 @@ class Node(SimComponent):
|
||||
|
||||
self.file_system.apply_timestep(timestep=timestep)
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""
|
||||
Scan the node and all the items within it.
|
||||
|
||||
@@ -987,8 +1037,9 @@ class Node(SimComponent):
|
||||
to the red agent.
|
||||
"""
|
||||
self.node_scan_countdown = self.node_scan_duration
|
||||
return True
|
||||
|
||||
def reveal_to_red(self) -> None:
|
||||
def reveal_to_red(self) -> bool:
|
||||
"""
|
||||
Reveals the node and all the items within it to the red agent.
|
||||
|
||||
@@ -1002,34 +1053,40 @@ class Node(SimComponent):
|
||||
`revealed_to_red` to `True`.
|
||||
"""
|
||||
self.red_scan_countdown = self.node_scan_duration
|
||||
return True
|
||||
|
||||
def power_on(self):
|
||||
def power_on(self) -> bool:
|
||||
"""Power on the Node, enabling its NICs if it is in the OFF state."""
|
||||
if self.operating_state == NodeOperatingState.OFF:
|
||||
self.operating_state = NodeOperatingState.BOOTING
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
|
||||
if self.start_up_duration <= 0:
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self._start_up_actions()
|
||||
self.sys_log.info("Power on")
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.enable()
|
||||
return True
|
||||
if self.operating_state == NodeOperatingState.OFF:
|
||||
self.operating_state = NodeOperatingState.BOOTING
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
return True
|
||||
|
||||
def power_off(self):
|
||||
return False
|
||||
|
||||
def power_off(self) -> bool:
|
||||
"""Power off the Node, disabling its NICs if it is in the ON state."""
|
||||
if self.shut_down_duration <= 0:
|
||||
self._shut_down_actions()
|
||||
self.operating_state = NodeOperatingState.OFF
|
||||
self.sys_log.info("Power off")
|
||||
return True
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.disable()
|
||||
self.operating_state = NodeOperatingState.SHUTTING_DOWN
|
||||
self.shut_down_countdown = self.shut_down_duration
|
||||
return True
|
||||
return False
|
||||
|
||||
if self.shut_down_duration <= 0:
|
||||
self._shut_down_actions()
|
||||
self.operating_state = NodeOperatingState.OFF
|
||||
self.sys_log.info("Power off")
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> bool:
|
||||
"""
|
||||
Resets the node.
|
||||
|
||||
@@ -1040,6 +1097,8 @@ class Node(SimComponent):
|
||||
self.is_resetting = True
|
||||
self.sys_log.info("Resetting")
|
||||
self.power_off()
|
||||
return True
|
||||
return False
|
||||
|
||||
def connect_nic(self, network_interface: NetworkInterface, port_name: Optional[str] = None):
|
||||
"""
|
||||
|
||||
@@ -51,13 +51,15 @@ class WirelessAccessPoint(IPWirelessNetworkInterface):
|
||||
|
||||
return state
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
|
||||
@@ -48,13 +48,15 @@ class WirelessNIC(IPWirelessNetworkInterface):
|
||||
|
||||
return state
|
||||
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return True
|
||||
|
||||
def send_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -293,6 +294,11 @@ class AccessControlList(SimComponent):
|
||||
self._acl = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
# TODO: Add src and dst wildcard masks as positional args in this request.
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
@@ -308,19 +314,24 @@ class AccessControlList(SimComponent):
|
||||
rm.add_request(
|
||||
"add_rule",
|
||||
RequestType(
|
||||
func=lambda request, context: self.add_rule(
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.add_rule(
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
|
||||
rm.add_request(
|
||||
"remove_rule",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.remove_rule(int(request[0])))),
|
||||
)
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
@@ -366,7 +377,7 @@ class AccessControlList(SimComponent):
|
||||
src_port: Optional[Port] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
position: int = 0,
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""
|
||||
Adds a new ACL rule to control network traffic based on specified criteria.
|
||||
|
||||
@@ -423,10 +434,12 @@ class AccessControlList(SimComponent):
|
||||
src_port=src_port,
|
||||
dst_port=dst_port,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.")
|
||||
return False
|
||||
|
||||
def remove_rule(self, position: int) -> None:
|
||||
def remove_rule(self, position: int) -> bool:
|
||||
"""
|
||||
Remove an ACL rule from a specific position.
|
||||
|
||||
@@ -437,8 +450,10 @@ class AccessControlList(SimComponent):
|
||||
rule = self._acl[position] # noqa
|
||||
self._acl[position] = None
|
||||
del rule
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.")
|
||||
return False
|
||||
|
||||
def is_permitted(self, frame: Frame) -> Tuple[bool, ACLRule]:
|
||||
"""Check if a packet with the given properties is permitted through the ACL."""
|
||||
@@ -1082,6 +1097,11 @@ class Router(NetworkNode):
|
||||
super().setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("acl", RequestType(func=self.acl._request_manager))
|
||||
return rm
|
||||
|
||||
@@ -146,9 +146,12 @@ def arcd_uc2_network() -> Network:
|
||||
)
|
||||
client_1.power_on()
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
db_client_1 = client_1.software_manager.install(DatabaseClient)
|
||||
db_client_1 = client_1.software_manager.software.get("DatabaseClient")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
db_client_1.run()
|
||||
web_browser_1 = client_1.software_manager.software.get("WebBrowser")
|
||||
web_browser_1.target_url = "http://arcd.com/users/"
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
db_manipulation_bot.configure(
|
||||
@@ -170,9 +173,10 @@ def arcd_uc2_network() -> Network:
|
||||
client_2.power_on()
|
||||
client_2.software_manager.install(DatabaseClient)
|
||||
db_client_2 = client_2.software_manager.software.get("DatabaseClient")
|
||||
db_client_2.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
db_client_2.run()
|
||||
web_browser = client_2.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
web_browser_2 = client_2.software_manager.software.get("WebBrowser")
|
||||
web_browser_2.target_url = "http://arcd.com/users/"
|
||||
network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2])
|
||||
|
||||
# Domain Controller
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Dict
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.domain.controller import DomainController
|
||||
from primaite.simulator.network.container import Network
|
||||
@@ -26,12 +27,18 @@ class Simulation(SimComponent):
|
||||
self.network.setup_for_episode(episode=episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
# pass through network requests to the network objects
|
||||
rm.add_request("network", RequestType(func=self.network._request_manager))
|
||||
# pass through domain requests to the domain object
|
||||
rm.add_request("domain", RequestType(func=self.domain._request_manager))
|
||||
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
|
||||
# if 'do_nothing' is requested, just return a success
|
||||
rm.add_request("do_nothing", RequestType(func=lambda request, context: RequestResponse(status="success")))
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -36,8 +37,13 @@ class DatabaseClient(Application):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: self.execute()))
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
|
||||
return rm
|
||||
|
||||
def execute(self) -> bool:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -74,9 +75,17 @@ class DataManipulationBot(Application):
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -179,23 +188,22 @@ class DataManipulationBot(Application):
|
||||
"""
|
||||
super().run()
|
||||
|
||||
def attack(self):
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
self._application_loop()
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop of the bot, handling the attack process.
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
self.num_executions += 1
|
||||
self.num_executions += 1
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self._logon()
|
||||
@@ -207,8 +215,12 @@ class DataManipulationBot(Application):
|
||||
DataManipulationAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = DataManipulationAttackStage.NOT_STARTED
|
||||
|
||||
return True
|
||||
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -57,9 +58,17 @@ class DoSBot(DatabaseClient):
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -97,26 +106,26 @@ class DoSBot(DatabaseClient):
|
||||
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
|
||||
)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> bool:
|
||||
"""Run the Denial of Service Bot."""
|
||||
super().run()
|
||||
self._application_loop()
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
return False
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return
|
||||
return True
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
@@ -126,6 +135,7 @@ class DoSBot(DatabaseClient):
|
||||
self.attack_stage = DoSAttackStage.NOT_STARTED
|
||||
else:
|
||||
self.attack_stage = DoSAttackStage.COMPLETED
|
||||
return True
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from urllib.parse import urlparse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
@@ -50,9 +51,17 @@ class WebBrowser(Application):
|
||||
self.run()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa
|
||||
name="execute",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.get_webpage())
|
||||
), # noqa
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
@@ -79,15 +80,20 @@ class Service(IOSoftware):
|
||||
return super().receive(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("stop", RequestType(func=lambda request, context: self.stop()))
|
||||
rm.add_request("start", RequestType(func=lambda request, context: self.start()))
|
||||
rm.add_request("pause", RequestType(func=lambda request, context: self.pause()))
|
||||
rm.add_request("resume", RequestType(func=lambda request, context: self.resume()))
|
||||
rm.add_request("restart", RequestType(func=lambda request, context: self.restart()))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop())))
|
||||
rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start())))
|
||||
rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause())))
|
||||
rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume())))
|
||||
rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart())))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
@@ -106,17 +112,19 @@ class Service(IOSoftware):
|
||||
state["health_state_visible"] = self.health_state_visible.value
|
||||
return state
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self) -> bool:
|
||||
"""Stop the service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
def start(self, **kwargs) -> bool:
|
||||
"""Start the service."""
|
||||
# cant start the service if the node it is on is off
|
||||
if not super()._can_perform_action():
|
||||
return
|
||||
return False
|
||||
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
@@ -124,36 +132,47 @@ class Service(IOSoftware):
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
return True
|
||||
return False
|
||||
|
||||
def pause(self) -> None:
|
||||
def pause(self) -> bool:
|
||||
"""Pause the service."""
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
return True
|
||||
return False
|
||||
|
||||
def resume(self) -> None:
|
||||
def resume(self) -> bool:
|
||||
"""Resume paused service."""
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
return True
|
||||
return False
|
||||
|
||||
def restart(self) -> None:
|
||||
def restart(self) -> bool:
|
||||
"""Restart running service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
self.restart_countdown = self.restart_duration
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self) -> None:
|
||||
def disable(self) -> bool:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
return True
|
||||
|
||||
def enable(self) -> None:
|
||||
def enable(self) -> bool:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -101,20 +102,27 @@ class Software(SimComponent):
|
||||
"Current number of ticks left to patch the software."
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
Initialise the request manager.
|
||||
|
||||
More information in user guide and docstring for SimComponent._init_request_manager.
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
"compromise",
|
||||
RequestType(
|
||||
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
),
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"patch",
|
||||
RequestType(
|
||||
func=lambda request, context: self.patch(),
|
||||
func=lambda request, context: RequestResponse.from_bool(self.patch()),
|
||||
),
|
||||
)
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
return rm
|
||||
|
||||
def _get_session_details(self, session_id: str) -> Session:
|
||||
@@ -148,7 +156,7 @@ class Software(SimComponent):
|
||||
)
|
||||
return state
|
||||
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> None:
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> bool:
|
||||
"""
|
||||
Assign a new health state to this software.
|
||||
|
||||
@@ -160,6 +168,7 @@ class Software(SimComponent):
|
||||
:type health_state: SoftwareHealthState
|
||||
"""
|
||||
self.health_state_actual = health_state
|
||||
return True
|
||||
|
||||
def install(self) -> None:
|
||||
"""
|
||||
@@ -180,15 +189,18 @@ class Software(SimComponent):
|
||||
"""
|
||||
pass
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Update the observed health status to match the actual health status."""
|
||||
self.health_state_visible = self.health_state_actual
|
||||
return True
|
||||
|
||||
def patch(self) -> None:
|
||||
def patch(self) -> bool:
|
||||
"""Perform a patch on the software."""
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
|
||||
Reference in New Issue
Block a user