Add agent action history

This commit is contained in:
Marek Wolan
2024-03-11 20:10:08 +00:00
parent 66ab5ec980
commit 7599655879
9 changed files with 110 additions and 104 deletions

View File

@@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reset_agent_for_episode()
self.setup_agent()
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def reset_agent_for_episode(self) -> None:
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
super().reset_agent_for_episode()
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)

View File

@@ -1,6 +1,6 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
@@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
class AgentActionHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
timestep: int
"""Timestep of this action."""
action: str
"""CAOS Action name."""
parameters: Dict[str, Any]
"""CAOS parameters for the given action."""
request: RequestFormat
"""The request that was sent to the simulation based on the CAOS action chosen."""
response: RequestResponse
"""The response sent back by the simulator for this action."""
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -90,6 +110,7 @@ class AbstractAgent(ABC):
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.action_history: List[AgentActionHistoryItem] = []
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -109,7 +130,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state)
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -138,9 +159,15 @@ class AbstractAgent(ABC):
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
def reset_agent_for_episode(self) -> None:
"""Agent reset logic should go here."""
pass
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.action_history.append(
AgentActionHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
)
class AbstractScriptedAgent(AbstractAgent):

View File

@@ -26,11 +26,14 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
_LOGGER = getLogger(__name__)
@@ -38,7 +41,9 @@ class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict) -> float:
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""Calculate the reward for the current state."""
return 0.0
@@ -58,7 +63,9 @@ class AbstractReward:
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict) -> float:
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""Calculate the reward for the current state."""
return 0.0
@@ -98,7 +105,9 @@ class DatabaseFileIntegrity(AbstractReward):
file_name,
]
def calculate(self, state: Dict) -> float:
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -153,7 +162,9 @@ class WebServer404Penalty(AbstractReward):
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict) -> float:
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -206,7 +217,9 @@ class WebpageUnavailablePenalty(AbstractReward):
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
def calculate(self, state: Dict) -> float:
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""
Calculate the reward based on current simulation state.
@@ -255,13 +268,17 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
def calculate(self, state: Dict) -> float:
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""
Calculate the reward based on current simulation state.
:param state: The current state of the simulation.
:type state: Dict
"""
if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]:
pass # TODO
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
@@ -313,7 +330,9 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def update(self, state: Dict) -> float:
def update(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -323,7 +342,7 @@ class RewardFunction:
for comp_and_weight in self.reward_components:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
self.current_reward = total
return self.current_reward

View File

@@ -1,6 +1,6 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Tuple
from typing import Dict, List
from pydantic import BaseModel, ConfigDict
@@ -130,49 +130,44 @@ class PrimaiteGame:
"""
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
action_data = self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state, and the response from the last action
self.update_agents(state=sim_state, action_data=action_data)
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for _, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
for agent_name, agent in self.agents.items():
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.update_observation(state=state)
agent.reward_function.total_reward += agent.reward_function.current_reward
def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]:
"""
Apply all actions to simulation as requests.
:return: A recap of each agent's actions, in CAOS format.
:rtype: Dict[str, Tuple[str, Dict]]
"""
agent_actions = {}
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
request = agent.format_request(action_choice, options)
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
request = agent.format_request(action_choice, parameters)
response = self.simulation.apply_request(request)
agent_actions[agent.agent_name] = {
"action": action_choice,
"parameters": options,
"response": response.model_dump(),
}
return agent_actions
agent.process_action_response(
timestep=self.step_counter,
action=action_choice,
parameters=parameters,
request=request,
response=response,
)
def advance_timestep(self) -> None:
"""Advance timestep."""

View File

@@ -1,7 +1,9 @@
from typing import Dict, ForwardRef, Literal
from typing import Dict, ForwardRef, List, Literal, Union
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
RequestFormat = List[Union[str, int, float]]
RequestResponse = ForwardRef("RequestResponse")
"""This makes it possible to type-hint RequestResponse.from_bool return type."""

View File

@@ -373,7 +373,7 @@
"# Imports\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"import yaml\n",
"from pprint import pprint\n"
]
@@ -425,14 +425,14 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict from step output and write out what the red agent is doing\n",
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info['action']\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str\n"
" return red_str"
]
},
{

View File

@@ -49,23 +49,20 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen by the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
info = {
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
@@ -91,13 +88,13 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
self.game.update_agents(state=state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@@ -217,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
agent_actions = self.game.apply_agent_actions()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -236,10 +233,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):

View File

@@ -48,8 +48,6 @@ class PrimaiteIO:
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
self.agent_action_log: List[Dict] = []
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
if timestamp is None:
@@ -72,48 +70,23 @@ class PrimaiteIO:
"""Return the path where agent actions will be saved."""
return self.session_path / "agent_actions" / f"episode_{episode}.json"
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
"""Cache agent actions for a particular step.
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
CAOS action is a string representing one of the CAOS actions.
parameters is a dict of parameter names and values for that particular CAOS action.
For example:
{
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
'defender': ('DO_NOTHING', {})
}
:type agent_actions: Dict
:param timestep: Simulation timestep when these actions occurred.
:type timestep: int
"""
self.agent_action_log.append(
[
{
"episode": episode,
"timestep": timestep,
"agent_actions": agent_actions,
}
]
)
def write_agent_actions(self, episode: int) -> None:
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
"""Take the contents of the agent action log and write it to a file.
:param episode: Episode number
:type episode: int
"""
data = {}
longest_history = max([len(hist) for hist in agent_actions])
for i in range(longest_history):
data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}}
path = self.generate_agent_actions_save_path(episode=episode)
path.parent.mkdir(exist_ok=True, parents=True)
path.touch()
_LOGGER.info(f"Saving agent action log to {path}")
with open(path, "w") as file:
json.dump(self.agent_action_log, fp=file, indent=1)
def clear_agent_actions(self) -> None:
"""Reset the agent action log back to an empty dictionary."""
self.agent_action_log = []
json.dump(data, fp=file, indent=1)
@classmethod
def from_config(cls, config: Dict) -> "PrimaiteIO":

View File

@@ -7,12 +7,10 @@ from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.interface.request import RequestFormat, RequestResponse
_LOGGER = getLogger(__name__)
RequestFormat = List[Union[str, int, float]]
class RequestPermissionValidator(BaseModel):
"""