Add data manipulation bot action manager

This commit is contained in:
Jake Walker
2023-11-23 16:06:19 +00:00
parent c93705867f
commit 5f1a5af1b4
5 changed files with 76 additions and 18 deletions

View File

@@ -78,12 +78,12 @@ game_config:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
- type: NODE_APPLICATION_EXECUTE
# - type: NODE_FILE_DELETE
# - type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:

View File

@@ -157,6 +157,37 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
self.verb = "enable"
class NodeApplicationAbstractAction(AbstractAction):
"""
Base class for application actions.
Any action which applies to an application and uses node_id and application_id as its only two parameters can
inherit from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str
def form_request(self, node_id: int, application_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id)
if node_uuid is None or application_uuid is None:
return ["do_nothing"]
return ["network", "node", node_uuid, "application", application_uuid, self.verb]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
"""Action which executes an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb = "execute"
class NodeFolderAbstractAction(AbstractAction):
"""
Base class for folder actions.
@@ -536,6 +567,7 @@ class ActionManager:
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
"NODE_FILE_SCAN": NodeFileScanAction,
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
"NODE_FILE_DELETE": NodeFileDeleteAction,
@@ -565,6 +597,7 @@ class ActionManager:
max_folders_per_node: int = 2, # allows calculating shape
max_files_per_folder: int = 2, # allows calculating shape
max_services_per_node: int = 2, # allows calculating shape
max_applications_per_node: int = 10, # allows calculating shape
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
@@ -622,6 +655,7 @@ class ActionManager:
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_applications": max_applications_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols),
@@ -775,6 +809,21 @@ class ActionManager:
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
"""Get the application UUID corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param application_idx: The index of the service on the node.
:type application_idx: int
:return: The UUID of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
application_uuids = list(node.applications.keys())
return application_uuids[application_idx] if len(application_uuids) > application_idx else None
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
"""Get the internet protocol corresponding to the given index.

View File

@@ -154,7 +154,17 @@ class DataManipulationAgent(AbstractScriptedAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.next_execution_timestep = self.agent_settings.start_settings.start_step
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
@@ -166,21 +176,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
# TODO: Move this to the appropriate place
# return self.action_space.get_action(self.action_space.space.sample())
current_timestep = self.action_space.session.step_counter
timestep = self.action_space.session.step_counter
if timestep < self.next_execution_timestep:
if current_timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
var = random.randint(-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance)
self.next_execution_timestep = timestep + self.agent_settings.start_settings.frequency + var
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
for bot in self.data_manipulation_bots:
bot.execute()
return "DONOTHING", {"dummy": 0}
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
class AbstractGATEAgent(AbstractAgent):

View File

@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
from typing import Optional
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -46,6 +47,13 @@ class DataManipulationBot(DatabaseClient):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=self.execute))
return rm
def configure(
self,
server_ip_address: IPv4Address,

View File

@@ -1,5 +1,3 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node