From f75c10aafb40f973d22e666f0ee7538b7e0e8676 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 5 Jan 2024 13:10:49 +0000 Subject: [PATCH] Make flattening observation spaces optional. --- CHANGELOG.md | 3 +- docs/source/config.rst | 24 +- .../config/_package_data/example_config.yaml | 2 +- src/primaite/game/agent/interface.py | 3 + src/primaite/game/game.py | 1 + .../training_example_ray_single_agent.ipynb | 221 +++++++++++++++++- src/primaite/session/environment.py | 15 +- src/primaite/session/session.py | 4 +- 8 files changed, 259 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e44efe3..c712ef66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] -- Made packet capture and system logging optional +- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config. +- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config. ### Added diff --git a/docs/source/config.rst b/docs/source/config.rst index f4452c7e..23bf6097 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -13,7 +13,25 @@ This section allows selecting which training framework and algorithm to use, and ``io_settings`` --------------- -This section configures how the ``PrimaiteSession`` saves data. +This section configures how PrimAITE saves data during simulation and training. + +**save_final_model**: Only used if training with PrimaiteSession, if true, the policy will be saved after the final training iteration. + +**save_checkpoints**: Only used if training with PrimaiteSession, if true, the policy will be saved periodically during training. 
+ +**checkpoint_interval**: Only used if training with PrimaiteSession and if ``save_checkpoints`` is true. Defines how often to save the policy during training. + +**save_logs**: *currently unused*. + +**save_transactions**: *currently unused*. + +**save_tensorboard_logs**: *currently unused*. + +**save_step_metadata**: Whether to save the RL agents' action, environment state, and other data at every single step. + +**save_pcap_logs**: Whether to save pcap files of all network traffic during the simulation. + +**save_sys_logs**: Whether to save system logs from all nodes during the simulation. ``game`` -------- @@ -56,6 +74,10 @@ Description of configurable items: **agent_settings**: Settings passed to the agent during initialisation. These depend on the agent class. +Reinforcement learning agents use the ``ProxyAgent`` class, which accepts these agent settings: + +**flatten_obs**: If true, gymnasium flattening will be performed on the observation space before it is sent to the agent. Set this to true if your agent does not support nested observation spaces. + ``simulation`` -------------- In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 81c9643e..2ac23661 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -525,7 +525,7 @@ agents: agent_settings: - # ... 
+ flatten_obs: true diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index fbbe5473..8657fc45 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -44,6 +44,7 @@ class AgentSettings(BaseModel): start_settings: Optional[AgentStartSettings] = None "Configuration for when an agent begins performing it's actions" + flatten_obs: bool = True @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": @@ -166,6 +167,7 @@ class ProxyAgent(AbstractAgent): action_space: Optional[ActionManager], observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], + agent_settings: Optional[AgentSettings] = None, ) -> None: super().__init__( agent_name=agent_name, @@ -174,6 +176,7 @@ class ProxyAgent(AbstractAgent): reward_function=reward_function, ) self.most_recent_action: ActType + self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings is not None else True def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index d2db4bea..586bca79 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -427,6 +427,7 @@ class PrimaiteGame: action_space=action_space, observation_space=obs_space, reward_function=rew_function, + agent_settings=agent_settings, ) game.agents.append(new_agent) game.rl_agents.append(new_agent) diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index a89b29e4..993e81ff 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -10,9 +10,64 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2024-01-05 12:46:28,650\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2024-01-05 12:46:31,581\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2024-01-05 12:46:31,903\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n", + "2024-01-05 12:46:35,016\tINFO worker.py:1673 -- Started a local Ray instance.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Python version:3.10.12
Ray version:2.8.0
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + "RayContext(dashboard_url='', python_version='3.10.12', ray_version='2.8.0', ray_commit='105355bd253d6538ed34d331f6a4bdf0e38ace3a', protocol_version=None)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from primaite.game.game import PrimaiteGame\n", "import yaml\n", @@ -41,7 +96,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'flatten_obs': False}\n" + ] + } + ], + "source": [ + "print(cfg['agents'][2]['agent_settings'])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -64,9 +136,141 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":job_id:01000000\n", + ":task_name:bundle_reservation_check_func\n", + ":actor_name:PPO\n", + "2024-01-05 12:46:40,174: Added service 6589f0e3-f427-4382-9e29-1344624bfe33 to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,175: Added service c3299cad-9a05-4fa1-bf6b-3fb82e9c9717 to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,176: Added application 2cba2c28-dc88-4688-b8bc-ac5cf2b7652c to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,178: Added service 23207441-15d6-42f8-bae4-1810bbff7d8b to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,179: Added service 01a61864-52d1-4980-b88c-6bec7ad58c3e to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,180: Added application a51c4db0-028c-440c-845e-43f7a1354544 to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,181: Added service 
7beff81f-6083-421b-a212-e02d9eb3ad69 to node 6f2396a1-80f4-4822-8d98-42811a89521e\n", + "2024-01-05 12:46:40,184: Added service e49fd236-0195-4571-a992-af490c2d27c4 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,186: Added service 9fdc6bb7-a338-4a64-b7a3-8467c88f79fd to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,188: Added application 3f99407c-1642-47e5-ade5-5106a1b49004 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,189: Added service 0e4c4e77-1bbb-45c3-aa4b-fdfd6c439091 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,190: Added service 711608ae-5f71-4bb7-8c99-95974f28f964 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,191: Added application bfbe2fb3-4d7e-4f07-9454-aee8404ca4b3 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,192: Added application 2cd88860-c7c5-4e64-b07a-4f0c9a0d8324 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,194: Added service 3cafdb32-3a89-4ab4-a22c-00beb29d6e71 to node 5a8e7052-0094-4104-aedb-beda65db2214\n", + "2024-01-05 12:46:40,196: Added service 649ff374-b9b3-4f17-94de-d95472cc94be to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,198: Added service 561374dc-8844-4a71-a577-67659130afaf to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,200: Added application 14eb20b8-ea9e-4027-a9ef-bf438b1f2b5e to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,202: Added service c7721159-10ad-4fd1-9fc7-a4403f89743a to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,203: Added service 907aff5d-c7d3-4d23-ab97-3bdaf92c8707 to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,204: Added application c8a55900-00af-46a7-90b5-bf8591130534 to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,206: Added service 9ae26c20-4c51-4283-b791-3c278c85aaef to node 
0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,207: Added service d3f108af-6a58-430b-9fc8-495e7db16968 to node 0053cdf7-44aa-4a44-b71c-a0351927e797\n", + "2024-01-05 12:46:40,211: Added service b759a0a5-7fe9-4f29-830e-6c50fe3d5ac0 to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,212: Added service d07213b5-d35b-4343-96ff-76399f80d12c to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,213: Added application f4cb45da-c81c-4fbf-adcf-461ca8728576 to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,215: Added service 44dadb4d-09b2-4569-97ed-18ed5e050437 to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,216: Added service 6c2e121a-fe1e-45fd-b0d4-587c0f6aafba to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,217: Added application e1ed96b9-221a-4a26-8330-1142f7681bf3 to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,218: Added service 4a9b52fb-747f-4921-bd73-2ee17557b2de to node 92240e65-db56-4b90-a1e3-a0e7d0d7e9a6\n", + "2024-01-05 12:46:40,220: Added service 38f3dfa9-6974-4122-b731-63a9cc3a13b2 to node 0c55c8bd-252b-420e-8ba6-e81091c21ff9\n", + "2024-01-05 12:46:40,220: Added service 5e2b34f4-9ac6-4e9d-b2db-48aac4eeff32 to node 0c55c8bd-252b-420e-8ba6-e81091c21ff9\n", + "2024-01-05 12:46:40,221: Added application 2db51ce9-391f-4e82-acf6-b565819b6c6d to node 0c55c8bd-252b-420e-8ba6-e81091c21ff9\n", + "2024-01-05 12:46:40,223: Added service e33f7cfb-6940-4076-9a2f-5874ba385c57 to node 0c55c8bd-252b-420e-8ba6-e81091c21ff9\n", + "2024-01-05 12:46:40,224: Added service 346687ac-a032-479b-9ccb-ab2df7d5b84b to node 0c55c8bd-252b-420e-8ba6-e81091c21ff9\n", + "2024-01-05 12:46:40,225: Added application 7adcddf8-4d1f-428b-8722-7ce44f1e64d7 to node 0c55c8bd-252b-420e-8ba6-e81091c21ff9\n", + "2024-01-05 12:46:40,229: Added service c498af8f-5648-4340-b117-7dd958d8bccb to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 
12:46:40,231: Added service 0218fc4e-fb25-47fb-b03c-9ecfa90986eb to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,232: Added application edfe50af-01ac-45e4-8b96-fdb5ec6d61d7 to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,233: Added service fb4b25f9-4a3f-41ec-a2db-57eac73201a9 to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,234: Added service 062e3e3e-65a4-4a30-ad34-418bdc2f5886 to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,235: Added application 72cdbee1-3ed9-4189-8198-788bfacacb44 to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,236: Added service f5b741a0-25a5-42cb-86e8-e42b2fec7433 to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,237: Added application 3fc736ef-9308-49a6-b63c-559ec878fc30 to node 88b3a7a8-bd87-48e3-b74b-5a66d2d770eb\n", + "2024-01-05 12:46:40,240: Added service 463f9765-6d2b-427c-9319-f4af92de3815 to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,241: Added service c4e6a1fc-7512-45e4-b8c1-60d0913b22d3 to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,242: Added application c1f981ba-db6b-4a1e-98a8-d8f0efb0ead9 to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,244: Added service 815c2d28-75e5-4ea6-a283-a18b389d4e6e to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,246: Added service c9ee86ec-9490-4fd8-8aef-6cab5c9e7c67 to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,247: Added application 91763a2f-8135-4fd3-a4cc-81c883d0d33f to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,248: Added service c71b0658-1e43-4881-b603-87dbf1c171a2 to node 65549a8a-9788-462b-9a12-883747b73a3b\n", + "2024-01-05 12:46:40,249: Added application 39ad54a1-28cd-4cd9-88d6-32e3f00844ae to node 65549a8a-9788-462b-9a12-883747b73a3b\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + ":job_id:01000000\n", + ":task_name:bundle_reservation_check_func\n", + ":actor_name:PPO\n", + "Resetting environment, episode -1, avg. reward: 0.0\n", + ":actor_name:PPO\n", + "Resetting environment, episode 0, avg. reward: 0.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":actor_name:PPO\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Resetting environment, episode 1, avg. reward: -101.0\n", + "Resetting environment, episode 2, avg. reward: -126.5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-01-05 12:46:46,735\tINFO storage.py:563 -- Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/cade/ray_results/PPO_2024-01-05_12-46-39/PPO_PrimaiteRayEnv_7899c_00000_0_2024-01-05_12-46-40/checkpoint_000000)\n", + "2024-01-05 12:46:46,847\tINFO tune.py:1047 -- Total run time: 6.85 seconds (6.77 seconds for the tuning loop).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "ResultGrid<[\n", + " Result(\n", + " metrics={'custom_metrics': {}, 'episode_media': {}, 'info': {'learner': {'__all__': {'num_agent_steps_trained': 128.0, 'num_env_steps_trained': 128.0, 'total_loss': 9.46448793411255}, 'default_policy': {'total_loss': 9.46448793411255, 'policy_loss': -0.06344481200600664, 'vf_loss': 9.525621096293131, 'vf_loss_unclipped': 509.6841542561849, 'vf_explained_var': 0.004743536313374837, 'entropy': 3.855365761121114, 'mean_kl_loss': 0.011559122211959523, 'default_optimizer_lr': 4.999999999999999e-05, 'curr_lr': 5e-05, 'curr_entropy_coeff': 0.0, 'curr_kl_coeff': 0.20000000298023224}}, 'num_env_steps_sampled': 512, 'num_env_steps_trained': 0, 'num_agent_steps_sampled': 512, 'num_agent_steps_trained': 0}, 'sampler_results': {'episode_reward_max': -101.0, 'episode_reward_min': -126.5, 'episode_reward_mean': -113.75, 'episode_len_mean': 256.0, 'episode_media': {}, 
'episodes_this_iter': 1, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-101.0, -126.5], 'episode_lengths': [256, 256]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 1.4790121096531605, 'mean_inference_ms': 2.438426005102573, 'mean_action_processing_ms': 0.12985192105746982, 'mean_env_wait_ms': 2.5040151965470656, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 1.051938533782959, 'StateBufferConnector_ms': 0.009810924530029297, 'ViewRequirementAgentConnector_ms': 0.46378374099731445}}, 'episode_reward_max': -101.0, 'episode_reward_min': -126.5, 'episode_reward_mean': -113.75, 'episode_len_mean': 256.0, 'episodes_this_iter': 1, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'hist_stats': {'episode_reward': [-101.0, -126.5], 'episode_lengths': [256, 256]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 1.4790121096531605, 'mean_inference_ms': 2.438426005102573, 'mean_action_processing_ms': 0.12985192105746982, 'mean_env_wait_ms': 2.5040151965470656, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 1.051938533782959, 'StateBufferConnector_ms': 0.009810924530029297, 'ViewRequirementAgentConnector_ms': 0.46378374099731445}, 'num_healthy_workers': 0, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0, 'num_agent_steps_sampled': 512, 'num_agent_steps_trained': 0, 'num_env_steps_sampled': 512, 'num_env_steps_trained': 0, 'num_env_steps_sampled_this_iter': 128, 'num_env_steps_trained_this_iter': 0, 'num_env_steps_sampled_throughput_per_sec': 57.18780249848288, 'num_env_steps_trained_throughput_per_sec': 0.0, 'num_steps_trained_this_iter': 0, 'agent_timesteps_total': 512, 'timers': {'training_iteration_time_ms': 1392.194, 'sample_time_ms': 995.05, 'synch_weights_time_ms': 1.92}, 'counters': {'num_env_steps_sampled': 512, 
'num_env_steps_trained': 0, 'num_agent_steps_sampled': 512, 'num_agent_steps_trained': 0}, 'perf': {'cpu_util_percent': 54.06666666666666, 'ram_util_percent': 53.53333333333334}},\n", + " path='/home/cade/ray_results/PPO_2024-01-05_12-46-39/PPO_PrimaiteRayEnv_7899c_00000_0_2024-01-05_12-46-40',\n", + " filesystem='local',\n", + " checkpoint=Checkpoint(filesystem=local, path=/home/cade/ray_results/PPO_2024-01-05_12-46-39/PPO_PrimaiteRayEnv_7899c_00000_0_2024-01-05_12-46-40/checkpoint_000000)\n", + " )\n", + "]>" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "tune.Tuner(\n", " \"PPO\",\n", @@ -76,6 +280,13 @@ " param_space=config\n", ").fit()\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index ca71a0c0..36ab3f58 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -23,6 +23,7 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.game: "PrimaiteGame" = game self.agent: ProxyAgent = self.game.rl_agents[0] + self.flatten_obs: bool = False def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -81,13 +82,19 @@ class PrimaiteGymEnv(gymnasium.Env): @property def observation_space(self) -> gymnasium.Space: """Return the observation space of the environment.""" - return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + if self.agent.flatten_obs: + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + else: + return self.agent.observation_manager.space def _get_obs(self) -> ObsType: """Return the current observation.""" - unflat_space = self.agent.observation_manager.space - unflat_obs = self.agent.observation_manager.current_observation - return 
gymnasium.spaces.flatten(unflat_space, unflat_obs) + if not self.agent.flatten_obs: + return self.agent.observation_manager.current_observation + else: + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) class PrimaiteRayEnv(gymnasium.Env): diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 0197ac9d..5c663cfd 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -101,9 +101,9 @@ class PrimaiteSession: # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"game": game}) + sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"game": game}) + sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) elif sess.training_options.rl_framework == "SB3": sess.env = PrimaiteGymEnv(game=game)