diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a96b27c..f79a1fd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added the ability for a DatabaseService to terminate a connection. - Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used. - Added additional show functions to enable connection inspection. +- Updates to agent logging, to include the reward both per step and per episode. ## [Unreleased] diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 2a5ec16f..26283ae9 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -404,7 +404,7 @@ " # don't flatten observations so that we can see what is going on\n", " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", "\n", - "env = PrimaiteGymEnv(game_config = cfg)\n", + "env = PrimaiteGymEnv(env_config = cfg)\n", "obs, info = env.reset()\n", "print('env created successfully')\n", "pprint(obs)" diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 140df1b8..ee51aa58 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -59,7 +59,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game_config=cfg)" + "gym = PrimaiteGymEnv(env_config=cfg)" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index abbf051b..6c42c701 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env): """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" - self.episode_counter: int = 0 """Current episode number.""" @@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen by the RL policy step = self.game.step_counter self.agent.store_action(action) - # apply_agent_actions accesses the action we just stored self.game.pre_timestep() + # apply_agent_actions accesses the action we just stored self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() @@ -58,6 +57,7 @@ class PrimaiteGymEnv(gymnasium.Env): next_obs = self._get_obs() # this doesn't update observation, just gets the current observation reward = self.agent.reward_function.current_reward + _LOGGER.info(f"step: {self.game.step_counter}, Blue reward: {reward}") terminated = False truncated = self.game.calculate_truncated() info = { @@ -204,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} + _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") + if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + self.episode_counter += 1 self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) self.game.setup_for_episode(episode=self.episode_counter) @@ -244,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 4. Get rewards rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}") terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} infos = {name: {} for name, _ in self.agents.items()} diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 22001fd2..75037381 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger, PRIMAITE_PATHS from primaite.simulator import LogLevel, SIM_OUTPUT -from src.primaite.utils.primaite_config_utils import is_dev_mode +from primaite.utils.primaite_config_utils import is_dev_mode _LOGGER = getLogger(__name__)