diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d08974c..e2989247 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Log observation space data by episode and step. + +## [3.3.0] - 2024-09-04 +### Added - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 14b97821..d5165a71 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -38,6 +38,9 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} + observation: Optional[ObsType] = None + """The observation space data for this step.""" + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -169,12 +172,23 @@ class AbstractAgent(ABC): return request def process_action_response( - self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + self, + timestep: int, + action: str, + parameters: Dict[str, Any], + request: RequestFormat, + response: RequestResponse, + observation: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( AgentHistoryItem( - timestep=timestep, action=action, parameters=parameters, request=request, response=response + timestep=timestep, + action=action, + parameters=parameters, + request=request, + response=response, + observation=observation, ) ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..4f21120d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -186,6 +186,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, + observation=obs, ) def pre_timestep(self) -> None: diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py new file mode 100644 index 00000000..e8bdea22 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -0,0 +1,28 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json + +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.io import PrimaiteIO +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +def test_obs_data_in_log_file(): + """Create a log file of AgentHistoryItems and check observation data is + included. Assumes that data_manipulation.yaml has an agent labelled + 'defender' with a non-null observation space. + The log file will be in: + primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions + """ + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.reset() + for _ in range(10): + env.step(0) + env.reset() + io = PrimaiteIO() + path = io.generate_agent_actions_save_path(episode=1) + with open(path, "r") as f: + j = json.load(f) + + assert type(j["0"]["defender"]["observation"]) == dict