diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 84e5e7df..05b76679 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -139,8 +139,12 @@ class PrimaiteGame: """ _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") + if self.step_counter == 0: + state = self.get_sim_state() + for agent in self.agents.values(): + agent.update_observation(state=state) # Apply all actions to simulation as requests - action_data = self.apply_agent_actions() + self.apply_agent_actions() # Advance timestep self.advance_timestep() @@ -149,7 +153,7 @@ class PrimaiteGame: sim_state = self.get_sim_state() # Update agents' observations and rewards based on the current state, and the response from the last action - self.update_agents(state=sim_state, action_data=action_data) + self.update_agents(state=sim_state) def get_sim_state(self) -> Dict: """Get the current state of the simulation.""" @@ -458,6 +462,7 @@ class PrimaiteGame: # Set the NMNE capture config set_nmne_config(network_config.get("nmne_config", {})) + game.update_agents(game.get_sim_state()) return game diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 64534b04..1795f14b 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -189,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" if self.io.settings.save_agent_actions: - self.io.write_agent_actions(episode=self.episode_counter) - self.io.clear_agent_actions() + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 diff --git a/tests/conftest.py b/tests/conftest.py index 20600e73..3a9e2655 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -531,4 +531,6 @@ def game_and_agent(): game.agents["test_agent"] = test_agent + game.setup_reward_sharing() + return (game, test_agent) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 56ba2b8f..cfd013bc 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,5 +1,6 @@ import yaml +from primaite.game.agent.interface import AgentActionHistoryItem from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv @@ -66,13 +67,18 @@ def test_uc2_rewards(game_and_agent): comp = GreenAdminDatabaseUnreachablePenalty("client_1") - db_client.apply_request( + response = db_client.apply_request( [ "execute", ] ) state = game.get_sim_state() - reward_value = comp.calculate(state) + reward_value = comp.calculate( + state, + last_action_response=AgentActionHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) assert reward_value == 1.0 router.acl.remove_rule(position=2) @@ -83,7 +89,12 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate(state) + reward_value = comp.calculate( + state, + last_action_response=AgentActionHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) assert reward_value == -1.0