Add data manipulation bot action manager
This commit is contained in:
@@ -78,12 +78,12 @@ game_config:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
# - type: NODE_FILE_DELETE
|
||||
# - type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
|
||||
@@ -157,6 +157,37 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NodeApplicationAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for application actions.
|
||||
|
||||
Any action which applies to an application and uses node_id and application_id as its only two parameters can
|
||||
inherit from this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, application_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id)
|
||||
if node_uuid is None or application_uuid is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_uuid, "application", application_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
|
||||
"""Action which executes an application."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
|
||||
self.verb = "execute"
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
@@ -536,6 +567,7 @@ class ActionManager:
|
||||
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
|
||||
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
|
||||
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
|
||||
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
|
||||
"NODE_FILE_SCAN": NodeFileScanAction,
|
||||
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
|
||||
"NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
@@ -565,6 +597,7 @@ class ActionManager:
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
max_files_per_folder: int = 2, # allows calculating shape
|
||||
max_services_per_node: int = 2, # allows calculating shape
|
||||
max_applications_per_node: int = 10, # allows calculating shape
|
||||
max_nics_per_node: int = 8, # allows calculating shape
|
||||
max_acl_rules: int = 10, # allows calculating shape
|
||||
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
|
||||
@@ -622,6 +655,7 @@ class ActionManager:
|
||||
"num_folders": max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
"num_services": max_services_per_node,
|
||||
"num_applications": max_applications_per_node,
|
||||
"num_nics": max_nics_per_node,
|
||||
"num_acl_rules": max_acl_rules,
|
||||
"num_protocols": len(self.protocols),
|
||||
@@ -775,6 +809,21 @@ class ActionManager:
|
||||
service_uuids = list(node.services.keys())
|
||||
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
|
||||
|
||||
def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
|
||||
"""Get the application UUID corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param application_idx: The index of the service on the node.
|
||||
:type application_idx: int
|
||||
:return: The UUID of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
application_uuids = list(node.applications.keys())
|
||||
return application_uuids[application_idx] if len(application_uuids) > application_idx else None
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
"""Get the internet protocol corresponding to the given index.
|
||||
|
||||
|
||||
@@ -154,7 +154,17 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.next_execution_timestep = self.agent_settings.start_settings.start_step
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
@@ -166,21 +176,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
# TODO: Move this to the appropriate place
|
||||
# return self.action_space.get_action(self.action_space.space.sample())
|
||||
current_timestep = self.action_space.session.step_counter
|
||||
|
||||
timestep = self.action_space.session.step_counter
|
||||
|
||||
if timestep < self.next_execution_timestep:
|
||||
if current_timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {"dummy": 0}
|
||||
|
||||
var = random.randint(-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance)
|
||||
self.next_execution_timestep = timestep + self.agent_settings.start_settings.frequency + var
|
||||
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
for bot in self.data_manipulation_bots:
|
||||
bot.execute()
|
||||
|
||||
return "DONOTHING", {"dummy": 0}
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
|
||||
@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
@@ -46,6 +47,13 @@ class DataManipulationBot(DatabaseClient):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DataManipulationBot"
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=self.execute))
|
||||
|
||||
return rm
|
||||
|
||||
def configure(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
Reference in New Issue
Block a user