Minor fixes

This commit is contained in:
Marek Wolan
2024-03-14 14:33:04 +00:00
parent f438acf745
commit d33c80d0d6
4 changed files with 25 additions and 7 deletions

View File

@@ -139,8 +139,12 @@ class PrimaiteGame:
"""
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
if self.step_counter == 0:
state = self.get_sim_state()
for agent in self.agents.values():
agent.update_observation(state=state)
# Apply all actions to simulation as requests
action_data = self.apply_agent_actions()
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
@@ -149,7 +153,7 @@ class PrimaiteGame:
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state, and the response from the last action
self.update_agents(state=sim_state, action_data=action_data)
self.update_agents(state=sim_state)
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
@@ -458,6 +462,7 @@ class PrimaiteGame:
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game

View File

@@ -189,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1

View File

@@ -531,4 +531,6 @@ def game_and_agent():
game.agents["test_agent"] = test_agent
game.setup_reward_sharing()
return (game, test_agent)

View File

@@ -1,5 +1,6 @@
import yaml
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
@@ -66,13 +67,18 @@ def test_uc2_rewards(game_and_agent):
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
db_client.apply_request(
response = db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
assert reward_value == 1.0
router.acl.remove_rule(position=2)
@@ -83,7 +89,12 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
assert reward_value == -1.0