From c5f131ece59eef137efaee89a141584aca4ae78a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 May 2024 15:00:18 +0100 Subject: [PATCH] fix reward logging --- src/primaite/game/agent/interface.py | 16 +++++++++++----- src/primaite/game/agent/rewards.py | 18 +++++++++--------- src/primaite/game/game.py | 1 + ...ta-Manipulation-Customising-Red-Agent.ipynb | 4 ++-- .../Data-Manipulation-E2E-Demonstration.ipynb | 6 +++--- .../Training-an-RLLIB-MARL-System.ipynb | 2 +- .../notebooks/Using-Episode-Schedules.ipynb | 8 ++++---- src/primaite/session/environment.py | 10 +++++----- src/primaite/session/io.py | 2 +- src/primaite/session/ray_envs.py | 8 ++++---- .../game_layer/test_rewards.py | 6 +++--- 11 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index cd4a1c29..444aa4f7 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: pass -class AgentActionHistoryItem(BaseModel): +class AgentHistoryItem(BaseModel): """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" timestep: int @@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel): response: RequestResponse """The response sent back by the simulator for this action.""" + reward: Optional[float] = None + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -110,7 +112,7 @@ class AbstractAgent(ABC): self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() - self.action_history: List[AgentActionHistoryItem] = [] + self.history: List[AgentHistoryItem] = [] def update_observation(self, state: Dict) -> ObsType: """ @@ -130,7 +132,7 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.update(state=state, last_action_response=self.action_history[-1]) + return self.reward_function.update(state=state, last_action_response=self.history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -161,12 +163,16 @@ class AbstractAgent(ABC): self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse ) -> None: """Process the response from the most recent action.""" - self.action_history.append( - AgentActionHistoryItem( + self.history.append( + AgentHistoryItem( timestep=timestep, action=action, parameters=parameters, request=request, response=response ) ) + def save_reward_to_history(self) -> None: + """Update the most recent history item with the reward value.""" + self.history[-1].reward = self.reward_function.current_reward + class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 0222bfcc..d77640d1 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -34,7 +34,7 @@ from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE if TYPE_CHECKING: - from primaite.game.agent.interface import AgentActionHistoryItem + from primaite.game.agent.interface import AgentHistoryItem _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] @@ -44,7 +44,7 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -64,7 +64,7 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] self._last_request_failed: bool = False - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] self._last_request_failed: bool = False - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -343,7 +343,7 @@ class SharedReward(AbstractReward): self.callback: Callable[[str], float] = default_callback """Method that retrieves an agent's current reward given the agent's name.""" - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it.""" return self.callback(self.agent_name) @@ -389,7 +389,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ea5b3831..772ab5aa 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -160,6 +160,7 @@ class PrimaiteGame: agent = self.agents[agent_name] if self.step_counter > 0: # can't get reward before first action agent.update_reward(state=state) + agent.save_reward_to_history() agent.update_observation(state=state) # order of this doesn't matter so just use reward order agent.reward_function.total_reward += agent.reward_function.current_reward diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index 1b016bb8..21d67bab 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -22,7 +22,7 @@ "# Imports\n", "\n", "from primaite.config.load import data_manipulation_config_path\n", - "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.game.agent.interface import AgentHistoryItem\n", "from primaite.session.environment import PrimaiteGymEnv\n", "import yaml\n", "from pprint import pprint" @@ -63,7 +63,7 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 8104149e..376b7f28 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -392,7 +392,7 @@ "# Imports\n", "from primaite.config.load import data_manipulation_config_path\n", "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.game.agent.interface import AgentHistoryItem\n", "import yaml\n", "from pprint import pprint\n" ] @@ -444,7 +444,7 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", @@ -705,7 +705,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index 65b1595f..61b988c6 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -25,7 +25,7 @@ "from primaite.game.game import PrimaiteGame\n", "import yaml\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from primaite import PRIMAITE_PATHS\n", "\n", "import ray\n", diff --git a/src/primaite/notebooks/Using-Episode-Schedules.ipynb b/src/primaite/notebooks/Using-Episode-Schedules.ipynb index b0669472..062c7135 100644 --- a/src/primaite/notebooks/Using-Episode-Schedules.ipynb +++ b/src/primaite/notebooks/Using-Episode-Schedules.ipynb @@ -298,8 +298,8 @@ "table = PrettyTable()\n", "table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n", "for i in range(21):\n", - " green_action = env.game.agents['green_A'].action_history[i].action\n", - " red_action = env.game.agents['red_A'].action_history[i].action\n", + " green_action = env.game.agents['green_A'].history[i].action\n", + " red_action = env.game.agents['red_A'].history[i].action\n", " table.add_row([i, green_action, red_action])\n", "print(table)" ] @@ -329,8 +329,8 @@ "table = PrettyTable()\n", "table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n", "for i in range(21):\n", - " green_action = env.game.agents['green_B'].action_history[i].action\n", - " red_action = env.game.agents['red_B'].action_history[i].action\n", + " green_action = env.game.agents['green_B'].history[i].action\n", + " red_action = env.game.agents['red_B'].history[i].action\n", " table.add_row([i, green_action, red_action])\n", "print(table)" ] diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index edb8a476..52edbbb8 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -60,7 +60,7 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = { - "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} + "agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()} } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(step, action, state, reward) @@ -89,8 +89,8 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) self.episode_counter += 1 self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter)) self.game.setup_for_episode(episode=self.episode_counter) @@ -125,5 +125,5 @@ class PrimaiteGymEnv(gymnasium.Env): def close(self): """Close the simulation.""" if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 8bbc1b07..2901457f 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -87,7 +87,7 @@ class PrimaiteIO: """Return the path where agent actions will be saved.""" return self.session_path / "agent_actions" / f"episode_{episode}.json" - def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None: + def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None: """Take the contents of the agent action log and write it to a file. :param episode: Episode number diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 5149a225..6dddde51 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -59,8 +59,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) self.episode_counter += 1 self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) @@ -138,8 +138,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def close(self): """Close the simulation.""" if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) class PrimaiteRayEnv(gymnasium.Env): diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 7c38057e..dff536de 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,6 +1,6 @@ import yaml -from primaite.game.agent.interface import AgentActionHistoryItem +from primaite.game.agent.interface import AgentHistoryItem from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv @@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate( state, - last_action_response=AgentActionHistoryItem( + last_action_response=AgentHistoryItem( timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response ), ) @@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate( state, - last_action_response=AgentActionHistoryItem( + last_action_response=AgentHistoryItem( timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response ), )