Merge remote-tracking branch 'origin/dev' into feature/2243-sync-uc2-action-to-iy
This commit is contained in:
@@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
|
||||
...
|
||||
- ref: green_agent_example
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
@@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive
|
||||
|
||||
``type``
|
||||
--------
|
||||
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``probabilistic_agent`` generate their own behaviour.
|
||||
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour.
|
||||
|
||||
Available agent types:
|
||||
|
||||
- ``probabilistic_agent``
|
||||
- ``ProbabilisticAgent``
|
||||
- ``ProxyAgent``
|
||||
- ``RedDatabaseCorruptingAgent``
|
||||
|
||||
|
||||
@@ -25,6 +25,13 @@ Usage
|
||||
- Clients connect, execute queries, and disconnect.
|
||||
- Service runs on TCP port 5432 by default.
|
||||
|
||||
**Supported queries:**
|
||||
|
||||
* ``SELECT``: As long as the database file is in a ``GOOD`` health state, the db service will respond with a 200 status code.
|
||||
* ``DELETE``: This query represents an attack; it will cause the database file to enter a ``COMPROMISED`` state and return a 200 status code.
|
||||
* ``INSERT``: If the database service is in a healthy state, this will return a 200 status; if it is not in a healthy state, it will return 404.
|
||||
* ``SELECT * FROM pg_stat_activity``: This query represents something an admin would send to check the status of the database. If the database service is in a healthy state, it returns a 200 status code, otherwise a 401 status code.
|
||||
|
||||
Implementation
|
||||
==============
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
@@ -76,7 +76,7 @@ agents:
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
|
||||
@@ -35,7 +35,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
@@ -78,7 +78,7 @@ agents:
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
|
||||
@@ -812,6 +812,13 @@ class ActionManager:
|
||||
:return: The node hostname.
|
||||
:rtype: str
|
||||
"""
|
||||
if not node_idx < len(self.node_names):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx}, but its action space only"
|
||||
f"has {len(self.node_names)} nodes."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.node_names[node_idx]
|
||||
|
||||
def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
|
||||
@@ -825,6 +832,13 @@ class ActionManager:
|
||||
:return: The name of the folder. Or None if the node has fewer folders than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this"
|
||||
f" is out of range for its action space. Folder on each node: {self.folder_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.folder_names[node_idx][folder_idx]
|
||||
|
||||
def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
|
||||
@@ -840,6 +854,17 @@ class ActionManager:
|
||||
fewer files than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if (
|
||||
node_idx >= len(self.file_names)
|
||||
or folder_idx >= len(self.file_names[node_idx])
|
||||
or file_idx >= len(self.file_names[node_idx][folder_idx])
|
||||
):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}"
|
||||
f" but this is out of range for its action space. Files on each node: {self.file_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.file_names[node_idx][folder_idx][file_idx]
|
||||
|
||||
def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
|
||||
@@ -852,6 +877,13 @@ class ActionManager:
|
||||
:return: The name of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this"
|
||||
f" is out of range for its action space. Services on each node: {self.service_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.service_names[node_idx][service_idx]
|
||||
|
||||
def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
|
||||
@@ -864,6 +896,13 @@ class ActionManager:
|
||||
:return: The name of the application. Or None if the node has fewer applications than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but "
|
||||
f"this is out of range for its action space. Applications on each node: {self.application_names}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.application_names[node_idx][application_idx]
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
@@ -874,6 +913,13 @@ class ActionManager:
|
||||
:return: The protocol.
|
||||
:rtype: str
|
||||
"""
|
||||
if protocol_idx >= len(self.protocols):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on protocol {protocol_idx} but this"
|
||||
f" is out of range for its action space. Protocols: {self.protocols}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
@@ -885,6 +931,13 @@ class ActionManager:
|
||||
:return: The IP address.
|
||||
:rtype: str
|
||||
"""
|
||||
if ip_idx >= len(self.ip_address_list):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on ip address {ip_idx} but this"
|
||||
f" is out of range for its action space. IP address list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
@@ -896,6 +949,13 @@ class ActionManager:
|
||||
:return: The port.
|
||||
:rtype: str
|
||||
"""
|
||||
if port_idx >= len(self.ports):
|
||||
msg = (
|
||||
f"Error: agent attempted to perform an action on port {port_idx} but this"
|
||||
f" is out of range for its action space. Port list: {self.ip_address_list}"
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
@@ -26,14 +26,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
|
||||
|
||||
:param obs: _description_
|
||||
:param obs: Current observation for this agent, not used in DataManipulationAgent
|
||||
:type obs: ObsType
|
||||
:param reward: _description_, defaults to None
|
||||
:type reward: float, optional
|
||||
:return: _description_
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
|
||||
@@ -112,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
return self.reward_function.update(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
@@ -152,14 +152,14 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: _description_
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
:type obs: ObsType
|
||||
:param reward: _description_, defaults to None
|
||||
:type reward: float, optional
|
||||
:return: _description_
|
||||
:param timestep: The current simulation timestep, not used in RandomAgent
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
@@ -185,14 +185,14 @@ class ProxyAgent(AbstractAgent):
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
|
||||
:type reward: float, optional
|
||||
:param timestep: Current simulation timestep. Not used by ProxyAgents, but required by the interface.
|
||||
:type timestep: int
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
|
||||
@@ -270,7 +270,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
return -1.0
|
||||
elif last_connection_successful is True:
|
||||
return 1.0
|
||||
return 0
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
|
||||
@@ -70,17 +70,17 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Choose a random action from the action space.
|
||||
Sample the action space randomly.
|
||||
|
||||
The probability of each action is given by the corresponding index in ``self.probabilities``.
|
||||
|
||||
:param obs: Current observation of the simulation
|
||||
:param obs: Current observation for this agent, not used in ProbabilisticAgent
|
||||
:type obs: ObsType
|
||||
:param reward: Reward for the last step, not used for scripted agents, defaults to 0
|
||||
:type reward: float, optional
|
||||
:return: Action to be taken in CAOS format.
|
||||
:param timestep: The current simulation timestep, not used in ProbabilisticAgent
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
|
||||
|
||||
@@ -165,8 +165,7 @@ class PrimaiteGame:
|
||||
agent_actions = {}
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter)
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
agent_actions[agent.agent_name] = (action_choice, options)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
@@ -414,7 +413,7 @@ class PrimaiteGame:
|
||||
reward_function = RewardFunction.from_config(reward_function_cfg)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "probabilistic_agent":
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
settings = agent_cfg.get("agent_settings")
|
||||
new_agent = ProbabilisticAgent(
|
||||
|
||||
@@ -16,7 +16,7 @@ class IPProtocol(Enum):
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
"""Placeholder for a non-port."""
|
||||
"""Placeholder for a non-protocol."""
|
||||
TCP = "tcp"
|
||||
"""Transmission Control Protocol."""
|
||||
UDP = "udp"
|
||||
|
||||
@@ -38,9 +38,7 @@ class DataManipulationAttackStage(IntEnum):
|
||||
class DataManipulationBot(Application):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
# server_ip_address: Optional[IPv4Address] = None
|
||||
payload: Optional[str] = None
|
||||
# server_password: Optional[str] = None
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -40,7 +40,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -40,7 +40,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -65,7 +65,7 @@ game:
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -31,7 +31,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -29,7 +29,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -25,7 +25,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: probabilistic_agent
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent):
|
||||
)
|
||||
self.most_recent_action: Tuple[str, Dict]
|
||||
|
||||
def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return the agent's most recent action, formatted in CAOS format."""
|
||||
return self.most_recent_action
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
|
||||
|
||||
def test_uc2_rewards(game_and_agent):
|
||||
"""Test that the reward component correctly applies a penalty when the selected client cannot reach the database."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ def test_probabilistic_agent():
|
||||
node_application_execute_count = 0
|
||||
node_file_delete_count = 0
|
||||
for _ in range(N_TRIALS):
|
||||
a = pa.get_action(0, timestep=0)
|
||||
a = pa.get_action(0)
|
||||
if a == ("DONOTHING", {}):
|
||||
do_nothing_count += 1
|
||||
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
|
||||
|
||||
Reference in New Issue
Block a user