Merged PR 360: 2523 - Logging Rewards updates.

## Summary
This PR updates `PrimaiteGymEnv` and `PrimaiteRayMARLEnv` within `environment.py` to log reward values per step and per episode.
Corrects an import statement within `io.py` - removing the use of `src`.
Fixes an error within notebooks that saw them failing following a rename of `game_config` to `env_config`.

## Test process
 - Tests all continue to pass, and notebooks now run without error. This PR does not change any functionality, just adds some logging.
 - Searched for the catch-22 that would trip me up as part of this ticket. A 3 line change felt like too simple an implementation...

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

Related work items: #2523
This commit is contained in:
Charlie Crane
2024-04-30 17:13:50 +00:00
5 changed files with 11 additions and 5 deletions

View File

@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added the ability for a DatabaseService to terminate a connection.
- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used.
- Added additional show functions to enable connection inspection.
- Updates to agent logging, to include the reward both per step and per episode.
## [Unreleased]

View File

@@ -404,7 +404,7 @@
" # don't flatten observations so that we can see what is going on\n",
" cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",
"env = PrimaiteGymEnv(env_config = cfg)\n",
"obs, info = env.reset()\n",
"print('env created successfully')\n",
"pprint(obs)"

View File

@@ -59,7 +59,7 @@
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game_config=cfg)"
"gym = PrimaiteGymEnv(env_config=cfg)"
]
},
{

View File

@@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
@@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen by the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.pre_timestep()
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
@@ -58,6 +57,7 @@ class PrimaiteGymEnv(gymnasium.Env):
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
reward = self.agent.reward_function.current_reward
_LOGGER.info(f"step: {self.game.step_counter}, Blue reward: {reward}")
terminated = False
truncated = self.game.calculate_truncated()
info = {
@@ -204,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
@@ -244,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger, PRIMAITE_PATHS
from primaite.simulator import LogLevel, SIM_OUTPUT
from src.primaite.utils.primaite_config_utils import is_dev_mode
from primaite.utils.primaite_config_utils import is_dev_mode
_LOGGER = getLogger(__name__)