Merged PR 304: Reward sharing
## Summary * add ability for agents to share rewards - the calculated reward value for one agent can be used as a component of another agent's reward. * Update UC2 configs to use the reward sharing functionality - green agents have a reward based on their two actions. Blue agent reward adds green agent rewards to itself. * Add agent action history - This allows the rewards to react to agent actions. * Make action logging use the new agent history. * Make the webpage and database reward components treat failed requests the same as if the webpage was unavailable / database was unreachable. * reorder the PrimaiteGame step to be the same as the Gymnasium env step * update uc2 notebook accordingly. ## Test process Tested with ad-hoc notebooks and debugging tool to verify correct data is being used. Tested notebooks run properly and pytests pass. Added unit and integration tests. ## Checklist - [ ] PR is linked to a **work item** - [ ] **acceptance criteria** of linked ticket are met - [ ] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [ ] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2372
This commit is contained in:
@@ -6,19 +6,12 @@ The Primaite codebase consists of two main modules:
|
||||
* ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes.
|
||||
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules.
|
||||
|
||||
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
|
||||
|
||||
..
|
||||
TODO: write up these APIs and link them here.
|
||||
|
||||
|
||||
Game layer
|
||||
----------
|
||||
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
|
||||
|
||||
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
|
||||
|
||||
PrimAITE Session
|
||||
^^^^^^^^^^^^^^^^
|
||||
================
|
||||
|
||||
.. admonition:: Deprecated
|
||||
:class: deprecated
|
||||
@@ -28,27 +21,76 @@ PrimAITE Session
|
||||
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
|
||||
|
||||
Agents
|
||||
^^^^^^
|
||||
======
|
||||
|
||||
All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types:
|
||||
|
||||
* RL agents action during each step is decided by an appropriate RL algorithm. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
|
||||
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable.
|
||||
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed is settable.
|
||||
|
||||
..
|
||||
TODO: add seed to stochastic scripted agents
|
||||
|
||||
Observations
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
============
|
||||
|
||||
An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class.
|
||||
|
||||
Actions
|
||||
^^^^^^^
|
||||
=======
|
||||
|
||||
An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class.
|
||||
|
||||
Rewards
|
||||
^^^^^^^
|
||||
=======
|
||||
|
||||
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class.
|
||||
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server.
|
||||
|
||||
Reward Components
|
||||
-----------------
|
||||
|
||||
Currently implemented are reward components tailored to the data manipulation scenario. View the full API and description of how they work here: :py:module:`primaite.game.agent.reward`.
|
||||
|
||||
Reward Sharing
|
||||
--------------
|
||||
|
||||
An agent's reward can be based on rewards of other agents. This is particularly useful for modelling a situation where the blue agent's job is to protect the ability of green agents to perform their pattern-of-life. This can be configured in the YAML file this way:
|
||||
|
||||
```yaml
|
||||
green_agent_1: # this agent sometimes tries to access the webpage, and sometimes the database
|
||||
# actions, observations, and agent settings go here
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
# When the webpage loads, the reward goes up by 0.25 when it fails to load, it goes down to -0.25
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
# When the database is reachable, the reward goes up by 0.05, when it is unreachable it goes down to -0.05
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
blue_agent:
|
||||
# actions, observations, and agent settings go here
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
# When the database file is in a good state, blue's reward is 0.4, when it's in a corrupted state the reward is -0.4
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.40
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
# The green's reward is added onto the blue's reward.
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_2_green_user
|
||||
|
||||
```
|
||||
|
||||
When defining agent reward sharing, users must be careful to avoid circular references, as that would lead to an infinite calculation loop. PrimAITE will prevent circular dependencies and provide a helpful error message if they are detected in the yaml.
|
||||
|
||||
@@ -76,7 +76,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
@@ -119,7 +126,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
|
||||
|
||||
@@ -699,22 +713,15 @@ agents:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
agent_name: client_1_green_user
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -75,7 +75,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
@@ -118,7 +125,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
|
||||
|
||||
@@ -700,22 +714,14 @@ agents:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
agent_name: client_1_green_user
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
agent_settings:
|
||||
@@ -1259,22 +1265,14 @@ agents:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
agent_name: client_1_green_user
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentActionHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
|
||||
timestep: int
|
||||
"""Timestep of this action."""
|
||||
|
||||
action: str
|
||||
"""CAOS Action name."""
|
||||
|
||||
parameters: Dict[str, Any]
|
||||
"""CAOS parameters for the given action."""
|
||||
|
||||
request: RequestFormat
|
||||
"""The request that was sent to the simulation based on the CAOS action chosen."""
|
||||
|
||||
response: RequestResponse
|
||||
"""The response sent back by the simulator for this action."""
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
@@ -90,6 +110,7 @@ class AbstractAgent(ABC):
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.action_history: List[AgentActionHistoryItem] = []
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -109,7 +130,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state)
|
||||
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -120,8 +141,6 @@ class AbstractAgent(ABC):
|
||||
|
||||
:param obs: Observation of the environment.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
|
||||
:type reward: float, optional
|
||||
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
|
||||
:type timestep: int
|
||||
:return: Action to be taken in the environment.
|
||||
@@ -138,9 +157,15 @@ class AbstractAgent(ABC):
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Agent reset logic should go here."""
|
||||
pass
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.action_history.append(
|
||||
AgentActionHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
|
||||
@@ -26,11 +26,16 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -38,7 +43,7 @@ class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -58,7 +63,7 @@ class AbstractReward:
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -98,7 +103,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -153,7 +158,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -203,16 +208,27 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
When the green agent requests to execute the browser application, and that request fails, this reward
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last webpage
|
||||
had a 200 status code, because there has been an unsuccessful request since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request did actually go through, then check if the webpage also loaded
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
_LOGGER.info(
|
||||
@@ -252,16 +268,28 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
When the green agent requests to execute the database client application, and that request fails, this reward
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last successful
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
@@ -284,6 +312,51 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Initialise the shared reward.
|
||||
|
||||
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
|
||||
correctly.
|
||||
|
||||
:param agent_name: The name whose reward is an input
|
||||
:type agent_name: Optional[str]
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
"""Agent whose reward to track."""
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
SharedReward should not be used until the game layer replaces self.callback with a reference to the
|
||||
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
|
||||
an error.
|
||||
"""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "SharedReward":
|
||||
"""
|
||||
Build the SharedReward object from config.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:type config: Dict
|
||||
"""
|
||||
agent_name = config.get("agent_name")
|
||||
return cls(agent_name=agent_name)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
@@ -293,6 +366,7 @@ class RewardFunction:
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
"SHARED_REWARD": SharedReward,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
@@ -313,7 +387,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def update(self, state: Dict) -> float:
|
||||
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -323,7 +397,7 @@ class RewardFunction:
|
||||
for comp_and_weight in self.reward_components:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
|
||||
self.current_reward = total
|
||||
return self.current_reward
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.reset_agent_for_episode()
|
||||
self.setup_agent()
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
super().reset_agent_for_episode()
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -8,9 +8,10 @@ from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -115,6 +116,9 @@ class PrimaiteGame:
|
||||
self.save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
self._reward_calculation_order: List[str] = [name for name in self.agents]
|
||||
"""Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -135,49 +139,49 @@ class PrimaiteGame:
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state
|
||||
self.update_agents(sim_state)
|
||||
|
||||
if self.step_counter == 0:
|
||||
state = self.get_sim_state()
|
||||
for agent in self.agents.values():
|
||||
agent.update_observation(state=state)
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state, and the response from the last action
|
||||
self.update_agents(state=sim_state)
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
return self.simulation.describe_state()
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for _, agent in self.agents.items():
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
for agent_name in self._reward_calculation_order:
|
||||
agent = self.agents[agent_name]
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]:
|
||||
"""
|
||||
Apply all actions to simulation as requests.
|
||||
|
||||
:return: A recap of each agent's actions, in CAOS format.
|
||||
:rtype: Dict[str, Tuple[str, Dict]]
|
||||
|
||||
"""
|
||||
agent_actions = {}
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
request = agent.format_request(action_choice, options)
|
||||
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
|
||||
request = agent.format_request(action_choice, parameters)
|
||||
response = self.simulation.apply_request(request)
|
||||
agent_actions[agent.agent_name] = {
|
||||
"action": action_choice,
|
||||
"parameters": options,
|
||||
"response": response.model_dump(),
|
||||
}
|
||||
return agent_actions
|
||||
agent.process_action_response(
|
||||
timestep=self.step_counter,
|
||||
action=action_choice,
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
@@ -453,7 +457,49 @@ class PrimaiteGame:
|
||||
raise ValueError(msg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
|
||||
"""Do necessary setup to enable reward sharing between agents.
|
||||
|
||||
This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1
|
||||
depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop.
|
||||
|
||||
Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is
|
||||
sharing their reward. This callback is provided by this setup method.
|
||||
|
||||
Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards
|
||||
that rely on the value of another reward are evaluated later.
|
||||
|
||||
:raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
|
||||
"""
|
||||
# construct dependency graph in the reward sharing between agents.
|
||||
graph = {}
|
||||
for name, agent in self.agents.items():
|
||||
graph[name] = set()
|
||||
for comp, weight in agent.reward_function.reward_components:
|
||||
if isinstance(comp, SharedReward):
|
||||
comp: SharedReward
|
||||
graph[name].add(comp.agent_name)
|
||||
|
||||
# while constructing the graph, we might as well set up the reward sharing itself.
|
||||
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
|
||||
|
||||
# make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
|
||||
if graph_has_cycle(graph):
|
||||
raise RuntimeError(
|
||||
(
|
||||
"Detected cycle in agent reward sharing. Check the agent reward function ",
|
||||
"configuration: reward sharing can only go one way.",
|
||||
)
|
||||
)
|
||||
|
||||
# sort the agents so the rewards that depend on other rewards are always evaluated later
|
||||
self._reward_calculation_order = topological_sort(graph)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from random import random
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
|
||||
def simulate_trial(p_of_success: float) -> bool:
|
||||
@@ -14,3 +15,80 @@ def simulate_trial(p_of_success: float) -> bool:
|
||||
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
|
||||
"""
|
||||
return random() < p_of_success
|
||||
|
||||
|
||||
def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
|
||||
"""Detect cycles in a directed graph.
|
||||
|
||||
Provide the graph as a dictionary that describes which nodes are linked. For example:
|
||||
{0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0
|
||||
{'a': ('b','c'), c:('b')} here there is no cycle
|
||||
|
||||
:param graph: a mapping from node to a set of nodes to which it is connected.
|
||||
:type graph: Mapping[Any, Iterable[Any]]
|
||||
:return: Whether the graph has any cycles
|
||||
:rtype: bool
|
||||
"""
|
||||
visited = set()
|
||||
currently_visiting = set()
|
||||
|
||||
def depth_first_search(node: Any) -> bool:
|
||||
"""Perform depth-first search (DFS) traversal to detect cycles starting from a given node."""
|
||||
if node in currently_visiting:
|
||||
return True # Cycle detected
|
||||
if node in visited:
|
||||
return False # Already visited, no need to explore further
|
||||
|
||||
visited.add(node)
|
||||
currently_visiting.add(node)
|
||||
|
||||
for neighbour in graph.get(node, []):
|
||||
if depth_first_search(neighbour):
|
||||
return True # Cycle detected
|
||||
|
||||
currently_visiting.remove(node)
|
||||
return False
|
||||
|
||||
# Start DFS traversal from each node
|
||||
for node in graph:
|
||||
if depth_first_search(node):
|
||||
return True # Cycle detected
|
||||
|
||||
return False # No cycles found
|
||||
|
||||
|
||||
def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
|
||||
"""
|
||||
Perform topological sorting on a directed graph.
|
||||
|
||||
This guarantees that if there's a directed edge from node A to node B, then A appears before B.
|
||||
|
||||
:param graph: A dictionary representing the directed graph, where keys are node identifiers
|
||||
and values are lists of outgoing edges from each node.
|
||||
:type graph: dict[int, list[Any]]
|
||||
|
||||
:return: A topologically sorted list of node identifiers.
|
||||
:rtype: list[Any]
|
||||
"""
|
||||
visited: set[Any] = set()
|
||||
stack: list[Any] = []
|
||||
|
||||
def dfs(node: Any) -> None:
|
||||
"""
|
||||
Depth-first search traversal to visit nodes and their neighbors.
|
||||
|
||||
:param node: The current node to visit.
|
||||
:type node: Any
|
||||
"""
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
for neighbour in graph.get(node, []):
|
||||
dfs(neighbour)
|
||||
stack.append(node)
|
||||
|
||||
# Perform DFS traversal from each node
|
||||
for node in graph:
|
||||
dfs(node)
|
||||
|
||||
return stack
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
from typing import Dict, ForwardRef, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint\n"
|
||||
]
|
||||
@@ -425,14 +425,14 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info['action']\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str\n"
|
||||
" return red_str"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -450,7 +450,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now the reward is -1, let's have a look at blue agent's observation."
|
||||
"Now the reward is -0.8, let's have a look at blue agent's observation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -510,9 +510,9 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n",
|
||||
"print(f\"Blue reward:{reward}\" )"
|
||||
]
|
||||
},
|
||||
@@ -533,9 +533,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
|
||||
"print(f\"Blue reward:{reward:.2f}\" )"
|
||||
@@ -557,17 +557,19 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(50) # Block client 1\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(51) # Block client 2\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"for step in range(30):\n",
|
||||
"while abs(reward - 0.8) > 1e-5:\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
" if env.game.step_counter > 10000:\n",
|
||||
" break # make sure there's no infinite loop if something went wrong"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -617,17 +619,19 @@
|
||||
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
|
||||
" print(\"blocking client 1\")\n",
|
||||
" break\n",
|
||||
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 2 has NMNEs, so let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
|
||||
" print(\"blocking client 2\")\n",
|
||||
" break\n",
|
||||
" if tries>100:\n",
|
||||
" print(\"Error: NMNE never increased\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"env.step(13) # Patch the database\n",
|
||||
"..."
|
||||
"print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -646,14 +650,14 @@
|
||||
"\n",
|
||||
"for step in range(40):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
|
||||
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -49,23 +49,20 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
info = {
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
@@ -91,13 +88,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
self.game.update_agents(state=state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
@@ -192,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
@@ -217,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -236,10 +233,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
|
||||
@@ -48,8 +48,6 @@ class PrimaiteIO:
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
|
||||
self.agent_action_log: List[Dict] = []
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
if timestamp is None:
|
||||
@@ -72,48 +70,24 @@ class PrimaiteIO:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
|
||||
"""Cache agent actions for a particular step.
|
||||
|
||||
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
|
||||
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
|
||||
CAOS action is a string representing one the CAOS actions.
|
||||
parameters is a dict of parameter names and values for that particular CAOS action.
|
||||
For example:
|
||||
{
|
||||
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
|
||||
'defender': ('DO_NOTHING', {})
|
||||
}
|
||||
:type agent_actions: Dict
|
||||
:param timestep: Simulation timestep when these actions occurred.
|
||||
:type timestep: int
|
||||
"""
|
||||
self.agent_action_log.append(
|
||||
[
|
||||
{
|
||||
"episode": episode,
|
||||
"timestep": timestep,
|
||||
"agent_actions": agent_actions,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def write_agent_actions(self, episode: int) -> None:
|
||||
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
:type episode: int
|
||||
"""
|
||||
data = {}
|
||||
longest_history = max([len(hist) for hist in agent_actions.values()])
|
||||
for i in range(longest_history):
|
||||
data[i] = {"timestep": i, "episode": episode}
|
||||
data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
|
||||
|
||||
path = self.generate_agent_actions_save_path(episode=episode)
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
path.touch()
|
||||
_LOGGER.info(f"Saving agent action log to {path}")
|
||||
with open(path, "w") as file:
|
||||
json.dump(self.agent_action_log, fp=file, indent=1)
|
||||
|
||||
def clear_agent_actions(self) -> None:
|
||||
"""Reset the agent action log back to an empty dictionary."""
|
||||
self.agent_action_log = []
|
||||
json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
|
||||
@@ -7,12 +7,10 @@ from uuid import uuid4
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
|
||||
class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
|
||||
956
tests/assets/configs/shared_rewards.yaml
Normal file
956
tests/assets/configs/shared_rewards.yaml
Normal file
@@ -0,0 +1,956 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
- ref: data_manipulation_attacker
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_PATCH
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_hostname: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
9: # check database.db file
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_PATCH"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 0
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 0
|
||||
20:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 0
|
||||
21:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 0
|
||||
22:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 1
|
||||
23:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 1
|
||||
24:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 1
|
||||
25:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 1
|
||||
26: # old action num: 18
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
27:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 2
|
||||
28:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 2
|
||||
29:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 2
|
||||
30:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
31:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 3
|
||||
32:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 3
|
||||
33:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 3
|
||||
34:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 4
|
||||
35:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 4
|
||||
36:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 4
|
||||
37:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 4
|
||||
38:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 5
|
||||
39: # old action num: 19 # shutdown client 1
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 5
|
||||
40: # old action num: 20
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 5
|
||||
41: # old action num: 21
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 5
|
||||
42:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 6
|
||||
43:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
44:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 6
|
||||
45:
|
||||
action: NODE_RESET
|
||||
options:
|
||||
node_id: 6
|
||||
|
||||
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
50: # old action num: 26
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
51: # old action num: 27
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
52: # old action num: 28
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
53: # old action num: 29
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
54: # old action num: 30
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
55: # old action num: 31
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
56: # old action num: 32
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
57: # old action num: 33
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
58: # old action num: 34
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
59: # old action num: 35
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
60: # old action num: 36
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
61: # old action num: 37
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
62: # old action num: 38
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
63: # old action num: 39
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
64: # old action num: 40
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
65: # old action num: 41
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
66: # old action num: 42
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
67: # old action num: 43
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
68: # old action num: 44
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
69: # old action num: 45
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
70: # old action num: 46
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
71: # old action num: 47
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
72: # old action num: 48
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
73: # old action num: 49
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
74: # old action num: 50
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
75: # old action num: 51
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
76: # old action num: 52
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
77: # old action num: 53
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_address_order:
|
||||
- node_name: domain_controller
|
||||
nic_num: 1
|
||||
- node_name: web_server
|
||||
nic_num: 1
|
||||
- node_name: database_server
|
||||
nic_num: 1
|
||||
- node_name: backup_server
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 1
|
||||
- node_name: client_1
|
||||
nic_num: 1
|
||||
- node_name: client_2
|
||||
nic_num: 1
|
||||
- node_name: security_suite
|
||||
nic_num: 2
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_1_green_user
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
capture_nmne: true
|
||||
nmne_capture_keywords:
|
||||
- DELETE
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
hostname: router_1
|
||||
type: router
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.10.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
18:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
19:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
20:
|
||||
action: PERMIT
|
||||
src_port: FTP
|
||||
dst_port: FTP
|
||||
21:
|
||||
action: PERMIT
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
hostname: switch_2
|
||||
type: switch
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
hostname: domain_controller
|
||||
type: server
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
hostname: web_server
|
||||
type: server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
applications:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
|
||||
- ref: database_server
|
||||
hostname: database_server
|
||||
type: server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
options:
|
||||
backup_server_ip: 192.168.1.16
|
||||
- ref: database_ftp_client
|
||||
type: FTPClient
|
||||
|
||||
- ref: backup_server
|
||||
hostname: backup_server
|
||||
type: server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: FTPServer
|
||||
|
||||
- ref: security_suite
|
||||
hostname: security_suite
|
||||
type: server
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
network_interfaces:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: client_1_web_browser
|
||||
type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- ref: client_1_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
- ref: client_2_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
@@ -531,4 +531,6 @@ def game_and_agent():
|
||||
|
||||
game.agents["test_agent"] = test_agent
|
||||
|
||||
game.setup_reward_sharing()
|
||||
|
||||
return (game, test_agent)
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import ControlledAgent
|
||||
|
||||
|
||||
@@ -61,13 +67,18 @@ def test_uc2_rewards(game_and_agent):
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
|
||||
db_client.apply_request(
|
||||
response = db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state)
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
assert reward_value == 1.0
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
@@ -78,5 +89,32 @@ def test_uc2_rewards(game_and_agent):
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state)
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
assert reward_value == -1.0
|
||||
|
||||
|
||||
def test_shared_reward():
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/shared_rewards.yaml"
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
env.reset()
|
||||
|
||||
order = env.game._reward_calculation_order
|
||||
assert order.index("defender") > order.index("client_1_green_user")
|
||||
assert order.index("defender") > order.index("client_2_green_user")
|
||||
|
||||
for step in range(256):
|
||||
act = env.action_space.sample()
|
||||
env.step(act)
|
||||
g1_reward = env.game.agents["client_1_green_user"].reward_function.current_reward
|
||||
g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward
|
||||
blue_reward = env.game.agents["defender"].reward_function.current_reward
|
||||
assert blue_reward == g1_reward + g2_reward
|
||||
|
||||
Reference in New Issue
Block a user