Merged PR 360: 2523 - Logging Rewards updates.
## Summary This PR updates `PrimaiteGymEnv` and `PrimaiteRayMARLEnv` within `environment.py` to log reward values per step and per episode. Corrects an import statement within `io.py` - removing the use of `src`. Fixes an error within notebooks that saw them failing following a rename of `game_config` to `env_config`. ## Test process - Tests all continue to pass, and notebooks now run without error. This PR does not change any functionality, just adds some logging. - Searched for the catch-22 that would trip me up as part of this ticket. A 3 line change felt like too simple an implementation... ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [X] attended to any **TO-DOs** left in the code Related work items: #2523
This commit is contained in:
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added the ability for a DatabaseService to terminate a connection.
|
||||
- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used.
|
||||
- Added additional show functions to enable connection inspection.
|
||||
- Updates to agent logging, to include the reward both per step and per episode.
|
||||
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
@@ -404,7 +404,7 @@
|
||||
" # don't flatten observations so that we can see what is going on\n",
|
||||
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(game_config = cfg)\n",
|
||||
"env = PrimaiteGymEnv(env_config = cfg)\n",
|
||||
"obs, info = env.reset()\n",
|
||||
"print('env created successfully')\n",
|
||||
"pprint(obs)"
|
||||
|
||||
@@ -59,7 +59,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(game_config=cfg)"
|
||||
"gym = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
@@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen by the RL policy
|
||||
step = self.game.step_counter
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.pre_timestep()
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
@@ -58,6 +57,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
|
||||
reward = self.agent.reward_function.current_reward
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Blue reward: {reward}")
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {
|
||||
@@ -204,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
@@ -244,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
from src.primaite.utils.primaite_config_utils import is_dev_mode
|
||||
from primaite.utils.primaite_config_utils import is_dev_mode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user