Merged PR 525: Log observation space data for each episode and step.

## Summary
Updated the `AgentHistoryItem` class so that it stores observation space data for every step of each episode. This means that `write_agent_log()` will log the additional data to a file, provided that `save_agent_actions` is set to `true` in the config file.
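As a rough illustration of what this adds to the per-episode log, here is a small sketch in Python. The JSON excerpt is hypothetical (the field values are made up), but the nesting — timestep key, then agent name, then an `observation` field — is inferred from the assertions in the new test added by this PR.

```python
import json

# Hypothetical excerpt of an agent_actions log file after this change.
# Entries are keyed by timestep, then by agent name; "observation" is the
# field this PR adds. Field contents here are invented for illustration.
sample_log = """
{
  "0": {
    "defender": {
      "action": "do_nothing",
      "observation": {"nodes": {"client_1": {"operating_status": 1}}}
    }
  }
}
"""

entries = json.loads(sample_log)
obs = entries["0"]["defender"]["observation"]
print(type(obs).__name__)  # dict
```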

## Test process
Tested the following notebooks on Linux and Windows:

- Data-Manipulation-E2E-Demonstration
- Training-an-SB3-Agent
- Training-an-RLLib-Agent
- Training-an-RLLIB-MARL-System

Wrote and passed a new test: `test_obs_data_in_log_file`.
All existing tests pass.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2845
Author: Nick Todd
Date: 2024-09-04 14:26:53 +00:00
4 changed files with 49 additions and 2 deletions


```diff
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
+
+### Added
+
+- Log observation space data by episode and step.
 ## [3.3.0] - 2024-09-04
 ### Added
 - Random Number Generator Seeding by specifying a random number seed in the config file.
 - Implemented Terminal service class, providing a generic terminal simulation.
 - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
```


```diff
@@ -38,6 +38,9 @@ class AgentHistoryItem(BaseModel):
     reward_info: Dict[str, Any] = {}
 
+    observation: Optional[ObsType] = None
+    """The observation space data for this step."""
+
 
 class AgentStartSettings(BaseModel):
     """Configuration values for when an agent starts performing actions."""
@@ -169,12 +172,23 @@ class AbstractAgent(ABC):
         return request
 
     def process_action_response(
-        self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
+        self,
+        timestep: int,
+        action: str,
+        parameters: Dict[str, Any],
+        request: RequestFormat,
+        response: RequestResponse,
+        observation: ObsType,
     ) -> None:
         """Process the response from the most recent action."""
         self.history.append(
             AgentHistoryItem(
-                timestep=timestep, action=action, parameters=parameters, request=request, response=response
+                timestep=timestep,
+                action=action,
+                parameters=parameters,
+                request=request,
+                response=response,
+                observation=observation,
             )
         )
```
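The pattern above — each history item carrying the observation that followed the action — can be sketched with minimal stand-in classes. These are hypothetical simplifications for illustration, not PrimAITE's actual `AgentHistoryItem` or `AbstractAgent` (which are pydantic models with more fields):

```python
from typing import Any, List, Optional


class HistoryItem:
    """Simplified stand-in for AgentHistoryItem: one record per step."""

    def __init__(self, timestep: int, action: str, observation: Optional[Any] = None):
        self.timestep = timestep
        self.action = action
        self.observation = observation  # observation space data for this step


class Agent:
    """Simplified stand-in for an agent with an action history."""

    def __init__(self) -> None:
        self.history: List[HistoryItem] = []

    def process_action_response(self, timestep: int, action: str, observation: Any) -> None:
        """Record the action together with the observation passed by the caller."""
        self.history.append(HistoryItem(timestep=timestep, action=action, observation=observation))


agent = Agent()
agent.process_action_response(timestep=0, action="do_nothing", observation={"node_status": 1})
print(agent.history[0].observation)  # {'node_status': 1}
```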


```diff
@@ -186,6 +186,7 @@ class PrimaiteGame:
                 parameters=parameters,
                 request=request,
                 response=response,
+                observation=obs,
             )
 
     def pre_timestep(self) -> None:
```


```diff
@@ -0,0 +1,28 @@
+# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
+import json
+from primaite.session.environment import PrimaiteGymEnv
+from primaite.session.io import PrimaiteIO
+from tests import TEST_ASSETS_ROOT
+
+DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
+
+
+def test_obs_data_in_log_file():
+    """Create a log file of AgentHistoryItems and check observation data is included.
+
+    Assumes that data_manipulation.yaml has an agent labelled 'defender' with a
+    non-null observation space.
+
+    The log file will be in: primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions
+    """
+    env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
+    env.reset()
+    for _ in range(10):
+        env.step(0)
+    env.reset()
+
+    io = PrimaiteIO()
+    path = io.generate_agent_actions_save_path(episode=1)
+    with open(path, "r") as f:
+        j = json.load(f)
+    assert type(j["0"]["defender"]["observation"]) == dict
```