Minor fixes
This commit is contained in:
@@ -139,8 +139,12 @@ class PrimaiteGame:
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
if self.step_counter == 0:
|
||||
state = self.get_sim_state()
|
||||
for agent in self.agents.values():
|
||||
agent.update_observation(state=state)
|
||||
# Apply all actions to simulation as requests
|
||||
action_data = self.apply_agent_actions()
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
@@ -149,7 +153,7 @@ class PrimaiteGame:
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state, and the response from the last action
|
||||
self.update_agents(state=sim_state, action_data=action_data)
|
||||
self.update_agents(state=sim_state)
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
@@ -458,6 +462,7 @@ class PrimaiteGame:
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
|
||||
@@ -189,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
|
||||
@@ -531,4 +531,6 @@ def game_and_agent():
|
||||
|
||||
game.agents["test_agent"] = test_agent
|
||||
|
||||
game.setup_reward_sharing()
|
||||
|
||||
return (game, test_agent)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
@@ -66,13 +67,18 @@ def test_uc2_rewards(game_and_agent):
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
|
||||
db_client.apply_request(
|
||||
response = db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state)
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
assert reward_value == 1.0
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
@@ -83,7 +89,12 @@ def test_uc2_rewards(game_and_agent):
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state)
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
assert reward_value == -1.0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user