Add agent action history
This commit is contained in:
@@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.reset_agent_for_episode()
|
||||
self.setup_agent()
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
super().reset_agent_for_episode()
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentActionHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
|
||||
timestep: int
|
||||
"""Timestep of this action."""
|
||||
|
||||
action: str
|
||||
"""CAOS Action name."""
|
||||
|
||||
parameters: Dict[str, Any]
|
||||
"""CAOS parameters for the given action."""
|
||||
|
||||
request: RequestFormat
|
||||
"""The request that was sent to the simulation based on the CAOS action chosen."""
|
||||
|
||||
response: RequestResponse
|
||||
"""The response sent back by the simulator for this action."""
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
@@ -90,6 +110,7 @@ class AbstractAgent(ABC):
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.action_history: List[AgentActionHistoryItem] = []
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -109,7 +130,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state)
|
||||
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -138,9 +159,15 @@ class AbstractAgent(ABC):
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
def reset_agent_for_episode(self) -> None:
|
||||
"""Agent reset logic should go here."""
|
||||
pass
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.action_history.append(
|
||||
AgentActionHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
|
||||
@@ -26,11 +26,14 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -38,7 +41,9 @@ class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -58,7 +63,9 @@ class AbstractReward:
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -98,7 +105,9 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -153,7 +162,9 @@ class WebServer404Penalty(AbstractReward):
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -206,7 +217,9 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
|
||||
@@ -255,13 +268,17 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]:
|
||||
pass # TODO
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
@@ -313,7 +330,9 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def update(self, state: Dict) -> float:
|
||||
def update(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -323,7 +342,7 @@ class RewardFunction:
|
||||
for comp_and_weight in self.reward_components:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
|
||||
self.current_reward = total
|
||||
return self.current_reward
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -130,49 +130,44 @@ class PrimaiteGame:
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
action_data = self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state, and the response from the last action
|
||||
self.update_agents(state=sim_state, action_data=action_data)
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
return self.simulation.describe_state()
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for _, agent in self.agents.items():
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
for agent_name, agent in self.agents.items():
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.update_observation(state=state)
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]:
|
||||
"""
|
||||
Apply all actions to simulation as requests.
|
||||
|
||||
:return: A recap of each agent's actions, in CAOS format.
|
||||
:rtype: Dict[str, Tuple[str, Dict]]
|
||||
|
||||
"""
|
||||
agent_actions = {}
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
for _, agent in self.agents.items():
|
||||
obs = agent.observation_manager.current_observation
|
||||
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
|
||||
request = agent.format_request(action_choice, options)
|
||||
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
|
||||
request = agent.format_request(action_choice, parameters)
|
||||
response = self.simulation.apply_request(request)
|
||||
agent_actions[agent.agent_name] = {
|
||||
"action": action_choice,
|
||||
"parameters": options,
|
||||
"response": response.model_dump(),
|
||||
}
|
||||
return agent_actions
|
||||
agent.process_action_response(
|
||||
timestep=self.step_counter,
|
||||
action=action_choice,
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
from typing import Dict, ForwardRef, List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint\n"
|
||||
]
|
||||
@@ -425,14 +425,14 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info['action']\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str\n"
|
||||
" return red_str"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -49,23 +49,20 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
info = {
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
@@ -91,13 +88,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
self.game.update_agents(state=state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
@@ -217,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -236,10 +233,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
|
||||
@@ -48,8 +48,6 @@ class PrimaiteIO:
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
|
||||
self.agent_action_log: List[Dict] = []
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
if timestamp is None:
|
||||
@@ -72,48 +70,23 @@ class PrimaiteIO:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
|
||||
"""Cache agent actions for a particular step.
|
||||
|
||||
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
|
||||
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
|
||||
CAOS action is a string representing one the CAOS actions.
|
||||
parameters is a dict of parameter names and values for that particular CAOS action.
|
||||
For example:
|
||||
{
|
||||
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
|
||||
'defender': ('DO_NOTHING', {})
|
||||
}
|
||||
:type agent_actions: Dict
|
||||
:param timestep: Simulation timestep when these actions occurred.
|
||||
:type timestep: int
|
||||
"""
|
||||
self.agent_action_log.append(
|
||||
[
|
||||
{
|
||||
"episode": episode,
|
||||
"timestep": timestep,
|
||||
"agent_actions": agent_actions,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def write_agent_actions(self, episode: int) -> None:
|
||||
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
:type episode: int
|
||||
"""
|
||||
data = {}
|
||||
longest_history = max([len(hist) for hist in agent_actions])
|
||||
for i in range(longest_history):
|
||||
data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}}
|
||||
|
||||
path = self.generate_agent_actions_save_path(episode=episode)
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
path.touch()
|
||||
_LOGGER.info(f"Saving agent action log to {path}")
|
||||
with open(path, "w") as file:
|
||||
json.dump(self.agent_action_log, fp=file, indent=1)
|
||||
|
||||
def clear_agent_actions(self) -> None:
|
||||
"""Reset the agent action log back to an empty dictionary."""
|
||||
self.agent_action_log = []
|
||||
json.dump(data, fp=file, indent=1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
|
||||
@@ -7,12 +7,10 @@ from uuid import uuid4
|
||||
from pydantic import BaseModel, ConfigDict, Field, validate_call
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
RequestFormat = List[Union[str, int, float]]
|
||||
|
||||
|
||||
class RequestPermissionValidator(BaseModel):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user