fix reward logging

This commit is contained in:
Marek Wolan
2024-05-31 15:00:18 +01:00
parent 04dc486708
commit c5f131ece5
11 changed files with 44 additions and 37 deletions

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
pass
class AgentActionHistoryItem(BaseModel):
class AgentHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
timestep: int
@@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel):
response: RequestResponse
"""The response sent back by the simulator for this action."""
reward: Optional[float] = None
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -110,7 +112,7 @@ class AbstractAgent(ABC):
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.action_history: List[AgentActionHistoryItem] = []
self.history: List[AgentHistoryItem] = []
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -130,7 +132,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
return self.reward_function.update(state=state, last_action_response=self.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -161,12 +163,16 @@ class AbstractAgent(ABC):
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.action_history.append(
AgentActionHistoryItem(
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
)
def save_reward_to_history(self) -> None:
    """
    Record the reward function's current reward on the newest history entry.

    The value is read from ``self.reward_function.current_reward`` and stored in the ``reward`` field of the
    last item in ``self.history``. Indexing the last item means an empty history raises ``IndexError``.
    """
    # Read the reward first, then attach it to the most recent entry.
    latest_reward = self.reward_function.current_reward
    self.history[-1].reward = latest_reward
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""

View File

@@ -34,7 +34,7 @@ from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
@@ -44,7 +44,7 @@ class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -64,7 +64,7 @@ class AbstractReward:
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
file_name,
]
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward):
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
@@ -343,7 +343,7 @@ class SharedReward(AbstractReward):
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback(self.agent_name)
@@ -389,7 +389,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.

View File

@@ -160,6 +160,7 @@ class PrimaiteGame:
agent = self.agents[agent_name]
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.save_reward_to_history()
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
agent.reward_function.total_reward += agent.reward_function.current_reward

View File

@@ -22,7 +22,7 @@
"# Imports\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.game.agent.interface import AgentHistoryItem\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml\n",
"from pprint import pprint"
@@ -63,7 +63,7 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict from step output and write out what the red agent is doing\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",

View File

@@ -392,7 +392,7 @@
"# Imports\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.game.agent.interface import AgentHistoryItem\n",
"import yaml\n",
"from pprint import pprint\n"
]
@@ -444,7 +444,7 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict from step output and write out what the red agent is doing\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
@@ -705,7 +705,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -25,7 +25,7 @@
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"import ray\n",

View File

@@ -298,8 +298,8 @@
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_A'].action_history[i].action\n",
" red_action = env.game.agents['red_A'].action_history[i].action\n",
" green_action = env.game.agents['green_A'].history[i].action\n",
" red_action = env.game.agents['red_A'].history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]
@@ -329,8 +329,8 @@
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_B'].action_history[i].action\n",
" red_action = env.game.agents['red_B'].action_history[i].action\n",
" green_action = env.game.agents['green_B'].history[i].action\n",
" red_action = env.game.agents['red_B'].history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]

View File

@@ -60,7 +60,7 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
"agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(step, action, state, reward)
@@ -89,8 +89,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
@@ -125,5 +125,5 @@ class PrimaiteGymEnv(gymnasium.Env):
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)

View File

@@ -87,7 +87,7 @@ class PrimaiteIO:
"""Return the path where agent actions will be saved."""
return self.session_path / "agent_actions" / f"episode_{episode}.json"
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
"""Take the contents of the agent action log and write it to a file.
:param episode: Episode number

View File

@@ -59,8 +59,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
@@ -138,8 +138,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):

View File

@@ -1,6 +1,6 @@
import yaml
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
@@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent):
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
@@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent):
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)