Merged PR 525: Log observation space data for each episode and step.
## Summary Updated `AgentHistoryItem` class so that it stores observation space data for every step of each episode. This means that `write_agent_log()` will log the additional data to file provided that `save_agent_actions` is set to `true` in the config file. ## Test process Tested on following notebooks on Linux and Windows: - Data-Manipulation-E2E-Demonstration - Training-an-SB3-Agent - Training-an-RLLib-Agent - Training-an-RLLIB-MARL-System. Wrote and passed new test: `test_obs_data_capture`. Passes all existing tests. ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2845
This commit is contained in:
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
### Added
|
||||
- Log observation space data by episode and step.
|
||||
|
||||
## [3.3.0] - 2024-09-04
|
||||
### Added
|
||||
- Random Number Generator Seeding by specifying a random number seed in the config file.
|
||||
- Implemented Terminal service class, providing a generic terminal simulation.
|
||||
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
|
||||
|
||||
@@ -38,6 +38,9 @@ class AgentHistoryItem(BaseModel):
|
||||
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
observation: Optional[ObsType] = None
|
||||
"""The observation space data for this step."""
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
@@ -169,12 +172,23 @@ class AbstractAgent(ABC):
|
||||
return request
|
||||
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
self,
|
||||
timestep: int,
|
||||
action: str,
|
||||
parameters: Dict[str, Any],
|
||||
request: RequestFormat,
|
||||
response: RequestResponse,
|
||||
observation: ObsType,
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
timestep=timestep,
|
||||
action=action,
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
observation=observation,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -186,6 +186,7 @@ class PrimaiteGame:
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
observation=obs,
|
||||
)
|
||||
|
||||
def pre_timestep(self) -> None:
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
|
||||
|
||||
|
||||
def test_obs_data_in_log_file():
|
||||
"""Create a log file of AgentHistoryItems and check observation data is
|
||||
included. Assumes that data_manipulation.yaml has an agent labelled
|
||||
'defender' with a non-null observation space.
|
||||
The log file will be in:
|
||||
primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions
|
||||
"""
|
||||
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
|
||||
env.reset()
|
||||
for _ in range(10):
|
||||
env.step(0)
|
||||
env.reset()
|
||||
io = PrimaiteIO()
|
||||
path = io.generate_agent_actions_save_path(episode=1)
|
||||
with open(path, "r") as f:
|
||||
j = json.load(f)
|
||||
|
||||
assert type(j["0"]["defender"]["observation"]) == dict
|
||||
Reference in New Issue
Block a user