Move data manipulation agent into individual file

This commit is contained in:
Jake Walker
2023-11-24 16:32:04 +00:00
parent afce6ca515
commit cbdaa6c444
4 changed files with 51 additions and 44 deletions

View File

@@ -0,0 +1,48 @@
import random
from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
class DataManipulationAgent(AbstractScriptedAgent):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
data_manipulation_bots: List["DataManipulationBot"] = []
next_execution_timestep: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.session.step_counter
if current_timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}

View File

@@ -1,5 +1,4 @@
"""Interface for agents."""
import random
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
@@ -11,7 +10,7 @@ from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
if TYPE_CHECKING:
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
pass
class AgentStartSettings(BaseModel):
@@ -183,47 +182,6 @@ class ProxyAgent(AbstractAgent):
self.most_recent_action = action
class DataManipulationAgent(AbstractScriptedAgent):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
data_manipulation_bots: List["DataManipulationBot"] = []
next_execution_timestep: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.session.step_counter
if current_timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""

View File

@@ -11,7 +11,8 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, ProxyAgent, RandomAgent
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings