Merged PR 304: Reward sharing

## Summary
* add ability for agents to share rewards - the calculated reward value for one agent can be used as a component of another agent's reward.
* Update UC2 configs to use the reward sharing functionality - green agents have a reward based on their two actions. Blue agent reward adds green agent rewards to itself.
* Add agent action history - This allows the rewards to react to agent actions.
* Make action logging use the new agent history.
* Make the webpage and database reward components treat failed requests the same as if the webpage was unavailable / database was unreachable.
* reorder the PrimaiteGame step to be the same as the Gymnasium env step
* update uc2 notebook accordingly.

## Test process
Tested with ad-hoc notebooks and debugging tool to verify correct data is being used. Tested notebooks run properly and pytests pass. Added unit and integration tests.

## Checklist
- [ ] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2372
This commit is contained in:
Marek Wolan
2024-03-15 10:42:12 +00:00
16 changed files with 1434 additions and 198 deletions

View File

@@ -6,19 +6,12 @@ The Primaite codebase consists of two main modules:
* ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes.
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules.
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
..
TODO: write up these APIs and link them here.
Game layer
----------
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
PrimAITE Session
^^^^^^^^^^^^^^^^
================
.. admonition:: Deprecated
:class: deprecated
@@ -28,27 +21,76 @@ PrimAITE Session
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents
^^^^^^
======
All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types:
* RL agents action during each step is decided by an appropriate RL algorithm. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable.
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed is settable.
..
TODO: add seed to stochastic scripted agents
Observations
^^^^^^^^^^^^^^^^^^
============
An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class.
Actions
^^^^^^^
=======
An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class.
Rewards
^^^^^^^
=======
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class.
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server.
Reward Components
-----------------
Currently implemented are reward components tailored to the data manipulation scenario. View the full API and description of how they work here: :py:module:`primaite.game.agent.reward`.
Reward Sharing
--------------
An agent's reward can be based on rewards of other agents. This is particularly useful for modelling a situation where the blue agent's job is to protect the ability of green agents to perform their pattern-of-life. This can be configured in the YAML file this way:
```yaml
green_agent_1: # this agent sometimes tries to access the webpage, and sometimes the database
# actions, observations, and agent settings go here
reward_function:
reward_components:
# When the webpage loads, the reward goes up by 0.25 when it fails to load, it goes down to -0.25
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
# When the database is reachable, the reward goes up by 0.05, when it is unreachable it goes down to -0.05
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
blue_agent:
# actions, observations, and agent settings go here
reward_function:
reward_components:
# When the database file is in a good state, blue's reward is 0.4, when it's in a corrupted state the reward is -0.4
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
# The green's reward is added onto the blue's reward.
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
```
When defining agent reward sharing, users must be careful to avoid circular references, as that would lead to an infinite calculation loop. PrimAITE will prevent circular dependencies and provide a helpful error message if they are detected in the yaml.

View File

@@ -76,7 +76,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
@@ -119,7 +126,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
@@ -699,22 +713,15 @@ agents:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
agent_name: client_2_green_user
agent_settings:

View File

@@ -75,7 +75,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
@@ -118,7 +125,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
@@ -700,22 +714,14 @@ agents:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
agent_name: client_2_green_user
agent_settings:
@@ -1259,22 +1265,14 @@ agents:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
agent_name: client_2_green_user
agent_settings:

View File

@@ -1,6 +1,6 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
@@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
class AgentActionHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
timestep: int
"""Timestep of this action."""
action: str
"""CAOS Action name."""
parameters: Dict[str, Any]
"""CAOS parameters for the given action."""
request: RequestFormat
"""The request that was sent to the simulation based on the CAOS action chosen."""
response: RequestResponse
"""The response sent back by the simulator for this action."""
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -90,6 +110,7 @@ class AbstractAgent(ABC):
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.action_history: List[AgentActionHistoryItem] = []
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -109,7 +130,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state)
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -120,8 +141,6 @@ class AbstractAgent(ABC):
:param obs: Observation of the environment.
:type obs: ObsType
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
:type reward: float, optional
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
:type timestep: int
:return: Action to be taken in the environment.
@@ -138,9 +157,15 @@ class AbstractAgent(ABC):
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
def reset_agent_for_episode(self) -> None:
"""Agent reset logic should go here."""
pass
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.action_history.append(
AgentActionHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
)
class AbstractScriptedAgent(AbstractAgent):

View File

@@ -26,11 +26,16 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
from typing_extensions import Never
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
_LOGGER = getLogger(__name__)
@@ -38,7 +43,7 @@ class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -58,7 +63,7 @@ class AbstractReward:
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -98,7 +103,7 @@ class DatabaseFileIntegrity(AbstractReward):
file_name,
]
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -153,7 +158,7 @@ class WebServer404Penalty(AbstractReward):
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -203,16 +208,27 @@ class WebpageUnavailablePenalty(AbstractReward):
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""
Calculate the reward based on current simulation state.
Calculate the reward based on current simulation state, and the recent agent action.
:param state: The current state of the simulation.
:type state: Dict
When the green agent requests to execute the browser application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
_LOGGER.info(
@@ -252,16 +268,28 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""
Calculate the reward based on current simulation state.
Calculate the reward based on current simulation state, and the recent agent action.
:param state: The current state of the simulation.
:type state: Dict
When the green agent requests to execute the database client application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
@@ -284,6 +312,51 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return cls(node_hostname=node_hostname)
class SharedReward(AbstractReward):
"""Adds another agent's reward to the overall reward."""
def __init__(self, agent_name: Optional[str] = None) -> None:
"""
Initialise the shared reward.
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
correctly.
:param agent_name: The name whose reward is an input
:type agent_name: Optional[str]
"""
self.agent_name = agent_name
"""Agent whose reward to track."""
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback(self.agent_name)
@classmethod
def from_config(cls, config: Dict) -> "SharedReward":
"""
Build the SharedReward object from config.
:param config: Configuration dictionary
:type config: Dict
"""
agent_name = config.get("agent_name")
return cls(agent_name=agent_name)
class RewardFunction:
"""Manages the reward function for the agent."""
@@ -293,6 +366,7 @@ class RewardFunction:
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
}
"""List of reward class identifiers."""
@@ -313,7 +387,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def update(self, state: Dict) -> float:
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -323,7 +397,7 @@ class RewardFunction:
for comp_and_weight in self.reward_components:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
self.current_reward = total
return self.current_reward

View File

@@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reset_agent_for_episode()
self.setup_agent()
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def reset_agent_for_episode(self) -> None:
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
super().reset_agent_for_episode()
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)

View File

@@ -1,6 +1,6 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
@@ -8,9 +8,10 @@ from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -115,6 +116,9 @@ class PrimaiteGame:
self.save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
self._reward_calculation_order: List[str] = [name for name in self.agents]
"""Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -135,49 +139,49 @@ class PrimaiteGame:
"""
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
if self.step_counter == 0:
state = self.get_sim_state()
for agent in self.agents.values():
agent.update_observation(state=state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state, and the response from the last action
self.update_agents(state=sim_state)
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for _, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
for agent_name in self._reward_calculation_order:
agent = self.agents[agent_name]
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
agent.reward_function.total_reward += agent.reward_function.current_reward
def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]:
"""
Apply all actions to simulation as requests.
:return: A recap of each agent's actions, in CAOS format.
:rtype: Dict[str, Tuple[str, Dict]]
"""
agent_actions = {}
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
request = agent.format_request(action_choice, options)
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
request = agent.format_request(action_choice, parameters)
response = self.simulation.apply_request(request)
agent_actions[agent.agent_name] = {
"action": action_choice,
"parameters": options,
"response": response.model_dump(),
}
return agent_actions
agent.process_action_response(
timestep=self.step_counter,
action=action_choice,
parameters=parameters,
request=request,
response=response,
)
def advance_timestep(self) -> None:
"""Advance timestep."""
@@ -453,7 +457,49 @@ class PrimaiteGame:
raise ValueError(msg)
game.agents[agent_cfg["ref"]] = new_agent
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game
def setup_reward_sharing(self):
"""Do necessary setup to enable reward sharing between agents.
This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1
depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop.
Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is
sharing their reward. This callback is provided by this setup method.
Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards
that rely on the value of another reward are evaluated later.
:raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
"""
# construct dependency graph in the reward sharing between agents.
graph = {}
for name, agent in self.agents.items():
graph[name] = set()
for comp, weight in agent.reward_function.reward_components:
if isinstance(comp, SharedReward):
comp: SharedReward
graph[name].add(comp.agent_name)
# while constructing the graph, we might as well set up the reward sharing itself.
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
# make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
if graph_has_cycle(graph):
raise RuntimeError(
(
"Detected cycle in agent reward sharing. Check the agent reward function ",
"configuration: reward sharing can only go one way.",
)
)
# sort the agents so the rewards that depend on other rewards are always evaluated later
self._reward_calculation_order = topological_sort(graph)

View File

@@ -1,4 +1,5 @@
from random import random
from typing import Any, Iterable, Mapping
def simulate_trial(p_of_success: float) -> bool:
@@ -14,3 +15,80 @@ def simulate_trial(p_of_success: float) -> bool:
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
"""
return random() < p_of_success
def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
"""Detect cycles in a directed graph.
Provide the graph as a dictionary that describes which nodes are linked. For example:
{0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0
{'a': ('b','c'), c:('b')} here there is no cycle
:param graph: a mapping from node to a set of nodes to which it is connected.
:type graph: Mapping[Any, Iterable[Any]]
:return: Whether the graph has any cycles
:rtype: bool
"""
visited = set()
currently_visiting = set()
def depth_first_search(node: Any) -> bool:
"""Perform depth-first search (DFS) traversal to detect cycles starting from a given node."""
if node in currently_visiting:
return True # Cycle detected
if node in visited:
return False # Already visited, no need to explore further
visited.add(node)
currently_visiting.add(node)
for neighbour in graph.get(node, []):
if depth_first_search(neighbour):
return True # Cycle detected
currently_visiting.remove(node)
return False
# Start DFS traversal from each node
for node in graph:
if depth_first_search(node):
return True # Cycle detected
return False # No cycles found
def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
"""
Perform topological sorting on a directed graph.
This guarantees that if there's a directed edge from node A to node B, then A appears before B.
:param graph: A dictionary representing the directed graph, where keys are node identifiers
and values are lists of outgoing edges from each node.
:type graph: dict[int, list[Any]]
:return: A topologically sorted list of node identifiers.
:rtype: list[Any]
"""
visited: set[Any] = set()
stack: list[Any] = []
def dfs(node: Any) -> None:
"""
Depth-first search traversal to visit nodes and their neighbors.
:param node: The current node to visit.
:type node: Any
"""
if node in visited:
return
visited.add(node)
for neighbour in graph.get(node, []):
dfs(neighbour)
stack.append(node)
# Perform DFS traversal from each node
for node in graph:
dfs(node)
return stack

View File

@@ -1,7 +1,9 @@
from typing import Dict, ForwardRef, Literal
from typing import Dict, ForwardRef, List, Literal, Union
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
RequestFormat = List[Union[str, int, float]]
RequestResponse = ForwardRef("RequestResponse")
"""This makes it possible to type-hint RequestResponse.from_bool return type."""

View File

@@ -373,7 +373,7 @@
"# Imports\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"import yaml\n",
"from pprint import pprint\n"
]
@@ -425,14 +425,14 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info['action']\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str\n"
" return red_str"
]
},
{
@@ -450,7 +450,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now the reward is -1, let's have a look at blue agent's observation."
"Now the reward is -0.8, let's have a look at blue agent's observation."
]
},
{
@@ -510,9 +510,9 @@
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n",
"print(f\"Blue reward:{reward}\" )"
]
},
@@ -533,9 +533,9 @@
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
"obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
"print(f\"Blue reward:{reward:.2f}\" )"
@@ -557,17 +557,19 @@
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(50) # Block client 1\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(51) # Block client 2\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"for step in range(30):\n",
"while abs(reward - 0.8) > 1e-5:\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
" if env.game.step_counter > 10000:\n",
" break # make sure there's no infinite loop if something went wrong"
]
},
{
@@ -617,17 +619,19 @@
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's block it\n",
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
" print(\"blocking client 1\")\n",
" break\n",
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 2 has NMNEs, so let's block it\n",
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
" print(\"blocking client 2\")\n",
" break\n",
" if tries>100:\n",
" print(\"Error: NMNE never increased\")\n",
" break\n",
"\n",
"env.step(13) # Patch the database\n",
"..."
"print()\n"
]
},
{
@@ -646,14 +650,14 @@
"\n",
"for step in range(40):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)"
]
},
{

View File

@@ -49,23 +49,20 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
info = {
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
@@ -91,13 +88,13 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
self.game.update_agents(state=state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@@ -192,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
@@ -217,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -236,10 +233,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):

View File

@@ -48,8 +48,6 @@ class PrimaiteIO:
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
self.agent_action_log: List[Dict] = []
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
if timestamp is None:
@@ -72,48 +70,24 @@ class PrimaiteIO:
"""Return the path where agent actions will be saved."""
return self.session_path / "agent_actions" / f"episode_{episode}.json"
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
"""Cache agent actions for a particular step.
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
CAOS action is a string representing one the CAOS actions.
parameters is a dict of parameter names and values for that particular CAOS action.
For example:
{
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
'defender': ('DO_NOTHING', {})
}
:type agent_actions: Dict
:param timestep: Simulation timestep when these actions occurred.
:type timestep: int
"""
self.agent_action_log.append(
[
{
"episode": episode,
"timestep": timestep,
"agent_actions": agent_actions,
}
]
)
def write_agent_actions(self, episode: int) -> None:
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
"""Take the contents of the agent action log and write it to a file.
:param episode: Episode number
:type episode: int
"""
data = {}
longest_history = max([len(hist) for hist in agent_actions.values()])
for i in range(longest_history):
data[i] = {"timestep": i, "episode": episode}
data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
path = self.generate_agent_actions_save_path(episode=episode)
path.parent.mkdir(exist_ok=True, parents=True)
path.touch()
_LOGGER.info(f"Saving agent action log to {path}")
with open(path, "w") as file:
json.dump(self.agent_action_log, fp=file, indent=1)
def clear_agent_actions(self) -> None:
"""Reset the agent action log back to an empty dictionary."""
self.agent_action_log = []
json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())
@classmethod
def from_config(cls, config: Dict) -> "PrimaiteIO":

View File

@@ -7,12 +7,10 @@ from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.interface.request import RequestFormat, RequestResponse
_LOGGER = getLogger(__name__)
RequestFormat = List[Union[str, int, float]]
class RequestPermissionValidator(BaseModel):
"""

View File

@@ -0,0 +1,956 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true
game:
max_episode_length: 256
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_PATCH
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 1
service_id: 0
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 1
service_id: 0
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 1
service_id: 0
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 1
service_id: 0
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 1
service_id: 0
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 1
service_id: 0
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 1
service_id: 0
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 1
service_id: 0
9: # check database.db file
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
options:
node_id: 2
service_id: 0
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
node_id: 0
19:
action: "NODE_SHUTDOWN"
options:
node_id: 0
20:
action: NODE_STARTUP
options:
node_id: 0
21:
action: NODE_RESET
options:
node_id: 0
22:
action: "NODE_OS_SCAN"
options:
node_id: 1
23:
action: "NODE_SHUTDOWN"
options:
node_id: 1
24:
action: NODE_STARTUP
options:
node_id: 1
25:
action: NODE_RESET
options:
node_id: 1
26: # old action num: 18
action: "NODE_OS_SCAN"
options:
node_id: 2
27:
action: "NODE_SHUTDOWN"
options:
node_id: 2
28:
action: NODE_STARTUP
options:
node_id: 2
29:
action: NODE_RESET
options:
node_id: 2
30:
action: "NODE_OS_SCAN"
options:
node_id: 3
31:
action: "NODE_SHUTDOWN"
options:
node_id: 3
32:
action: NODE_STARTUP
options:
node_id: 3
33:
action: NODE_RESET
options:
node_id: 3
34:
action: "NODE_OS_SCAN"
options:
node_id: 4
35:
action: "NODE_SHUTDOWN"
options:
node_id: 4
36:
action: NODE_STARTUP
options:
node_id: 4
37:
action: NODE_RESET
options:
node_id: 4
38:
action: "NODE_OS_SCAN"
options:
node_id: 5
39: # old action num: 19 # shutdown client 1
action: "NODE_SHUTDOWN"
options:
node_id: 5
40: # old action num: 20
action: NODE_STARTUP
options:
node_id: 5
41: # old action num: 21
action: NODE_RESET
options:
node_id: 5
42:
action: "NODE_OS_SCAN"
options:
node_id: 6
43:
action: "NODE_SHUTDOWN"
options:
node_id: 6
44:
action: NODE_STARTUP
options:
node_id: 6
45:
action: NODE_RESET
options:
node_id: 6
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
options:
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
options:
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
50: # old action num: 26
action: "NETWORK_ACL_ADDRULE"
options:
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
51: # old action num: 27
action: "NETWORK_ACL_ADDRULE"
options:
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
52: # old action num: 28
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
53: # old action num: 29
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
54: # old action num: 30
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
55: # old action num: 31
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
56: # old action num: 32
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
57: # old action num: 33
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
58: # old action num: 34
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
59: # old action num: 35
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
60: # old action num: 36
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
61: # old action num: 37
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
62: # old action num: 38
action: "NETWORK_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
63: # old action num: 39
action: "NETWORK_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
64: # old action num: 40
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
65: # old action num: 41
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
66: # old action num: 42
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
67: # old action num: 43
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
68: # old action num: 44
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
69: # old action num: 45
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
70: # old action num: 46
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
71: # old action num: 47
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
72: # old action num: 48
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
73: # old action num: 49
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
74: # old action num: 50
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
75: # old action num: 51
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
76: # old action num: 52
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
77: # old action num: 53
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
reward_function:
reward_components:
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
agent_settings:
flatten_obs: true
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- ref: router_1
hostname: router_1
type: router
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
hostname: switch_1
type: switch
num_ports: 8
- ref: switch_2
hostname: switch_2
type: switch
num_ports: 8
- ref: domain_controller
hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: database_server
hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- ref: database_ftp_client
type: FTPClient
- ref: backup_server
hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: FTPServer
- ref: security_suite
hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
network_interfaces:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_1_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: client_1_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_2_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -531,4 +531,6 @@ def game_and_agent():
game.agents["test_agent"] = test_agent
game.setup_reward_sharing()
return (game, test_agent)

View File

@@ -1,10 +1,16 @@
import yaml
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
from tests import TEST_ASSETS_ROOT
from tests.conftest import ControlledAgent
@@ -61,13 +67,18 @@ def test_uc2_rewards(game_and_agent):
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
db_client.apply_request(
response = db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
assert reward_value == 1.0
router.acl.remove_rule(position=2)
@@ -78,5 +89,32 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
assert reward_value == -1.0
def test_shared_reward():
CFG_PATH = TEST_ASSETS_ROOT / "configs/shared_rewards.yaml"
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env.reset()
order = env.game._reward_calculation_order
assert order.index("defender") > order.index("client_1_green_user")
assert order.index("defender") > order.index("client_2_green_user")
for step in range(256):
act = env.action_space.sample()
env.step(act)
g1_reward = env.game.agents["client_1_green_user"].reward_function.current_reward
g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward
blue_reward = env.game.agents["defender"].reward_function.current_reward
assert blue_reward == g1_reward + g2_reward