# Source file: PrimAITE/src/primaite/primaite_session.py
# (217 lines, 7.0 KiB, Python; revision dated 2023-06-13 09:42:54 +01:00)

from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Final, Optional, Union
from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
# Module-level logger, named after this module, used by everything below.
_LOGGER = getLogger(__name__)
def _get_session_path(session_timestamp: datetime) -> Path:
    """
    Build, create and return the output directory for a session.

    The directory sits under ``SESSIONS_DIR`` and is laid out as:
        <SESSIONS_DIR>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>

    Missing parent directories are created, and an already existing
    directory is not treated as an error.

    :param session_timestamp: The datetime at which the session started.
    :return: The session directory path.
    """
    # One folder per day, containing one folder per session start time.
    day_folder, run_folder = (
        session_timestamp.strftime(pattern)
        for pattern in ("%Y-%m-%d", "%Y-%m-%d_%H-%M-%S")
    )
    session_path = SESSIONS_DIR.joinpath(day_folder, run_folder)
    session_path.mkdir(parents=True, exist_ok=True)
    _LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
    return session_path
class PrimaiteSession:
    """
    A single PrimAITE session.

    On construction the session is given a UUID, a start timestamp and an
    output directory. ``setup`` builds the ``Primaite`` environment from the
    two config files and writes the session metadata file; ``learn`` runs the
    environment against the agent named in the training config, then saves
    the transaction logs and updates the metadata file.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
        auto: bool = True,
    ):
        """
        Create the session and, optionally, run it straight away.

        :param training_config_path: Path to the training config file.
        :param lay_down_config_path: Path to the lay down config file.
        :param auto: If True (the default), ``setup()`` and then ``learn()``
            are called before the constructor returns.
        """
        self._training_config_path: Final[Path] = Path(training_config_path)
        self._lay_down_config_path: Final[Path] = Path(lay_down_config_path)
        self._auto: Final[bool] = auto

        # Not Final: ``import_agent`` resets the UUID after construction.
        self._uuid: str = str(uuid4())
        self._session_timestamp: Final[datetime] = datetime.now()
        self._session_path: Final[Path] = _get_session_path(
            self._session_timestamp
        )
        self._timestamp_str: Final[str] = self._session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S"
        )
        self._metadata_path: Final[Path] = (
            self._session_path / "session_metadata.json"
        )

        # Populated by ``setup``.
        self._env: Optional[Primaite] = None
        self._training_config: Optional[TrainingConfig] = None
        # The list handed to the environment; kept so ``learn`` can write it
        # out once the run has finished.
        self._transaction_list: list = []
        self._can_learn: bool = False

        _LOGGER.debug(
            f"Initialised PrimAITE Session {self._uuid}: {self._session_path}"
        )

        if self._auto:
            self.setup()
            self.learn()

    @property
    def uuid(self):
        """The session UUID."""
        return self._uuid

    def _setup_primaite_env(self, transaction_list: Optional[list] = None):
        """
        Instantiate the ``Primaite`` environment for this session.

        :param transaction_list: The list passed to the environment as its
            ``transaction_list``. A new empty list is used when omitted. A
            caller-supplied list (even an empty one) is used as-is so the
            caller keeps a reference to the same object.
        """
        if transaction_list is None:
            transaction_list = []
        self._transaction_list = transaction_list
        self._env = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            transaction_list=self._transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str,
        )
        self._training_config = self._env.training_config

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the session directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items

        Requires the environment to have been created first.
        """
        metadata_dict = {
            "uuid": self._uuid,
            "start_datetime": self._session_timestamp.isoformat(),
            "end_datetime": None,
            "total_episodes": None,
            "total_time_steps": None,
            "env": {
                "training_config": self._env.training_config.to_dict(
                    json_serializable=True
                ),
                "lay_down_config": self._env.lay_down_config,
            },
        }
        _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Re-reads the existing ``session_metadata.json`` from the session
        directory and overwrites the following key/value pairs:

        - end_datetime: The current date & time in iso format.
        - total_episodes: The environment's ``episode_count``.
        - total_time_steps: The environment's ``total_step_count``.

        The file must already exist (it is written by ``setup``).
        """
        with open(self._metadata_path, "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()
        metadata_dict["total_episodes"] = self._env.episode_count
        metadata_dict["total_time_steps"] = self._env.total_step_count

        _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def setup(self):
        """
        Prepare the session so that ``learn`` can run.

        Creates the environment, writes the initial session metadata file
        (which ``learn`` later updates) and marks the session as ready.
        """
        self._setup_primaite_env()
        self._write_session_metadata_file()
        self._can_learn = True

    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        iterations: Optional[int] = None,
        **kwargs,
    ):
        """
        Run the environment against the configured agent.

        Does nothing if ``setup`` has not been called. Otherwise the agent
        is selected by the training config's ``agent_identifier``; afterwards
        the transaction logs are written and the session metadata file is
        updated.

        :param time_steps: Currently unused.
        :param episodes: Currently unused.
        :param iterations: Currently unused.
        :param kwargs: Currently unused.
        """
        if not self._can_learn:
            return

        # NOTE(review): run_generic, run_stable_baselines3_ppo,
        # run_stable_baselines3_a2c and write_transaction_to_file are not
        # imported in the visible part of this module; they must be brought
        # into scope before this method can complete.
        agent_identifier = self._training_config.agent_identifier

        # Run environment against an agent
        if agent_identifier == "GENERIC":
            run_generic(env=self._env, config_values=self._training_config)
        elif agent_identifier == "STABLE_BASELINES3_PPO":
            run_stable_baselines3_ppo(
                env=self._env,
                config_values=self._training_config,
                session_path=self._session_path,
                timestamp_str=self._timestamp_str,
            )
        elif agent_identifier == "STABLE_BASELINES3_A2C":
            run_stable_baselines3_a2c(
                env=self._env,
                config_values=self._training_config,
                session_path=self._session_path,
                timestamp_str=self._timestamp_str,
            )
        else:
            _LOGGER.warning(
                f"Unrecognised agent identifier '{agent_identifier}'; "
                "no agent was run"
            )

        print("Session finished")
        _LOGGER.debug("Session finished")

        print("Saving transaction logs...")
        write_transaction_to_file(
            transaction_list=self._transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str,
        )

        print("Updating Session Metadata file...")
        self._update_session_metadata_file()

        print("Finished")
        _LOGGER.debug("Finished")

    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs,
    ):
        """Evaluate the agent. Not implemented yet; does nothing."""
        pass

    def export(self):
        """Export the session. Not implemented yet; does nothing."""
        pass

    @classmethod
    def import_agent(
        cls,
        gent_path: str,
        training_config_path: str,
        lay_down_config_path: str,
    ) -> "PrimaiteSession":
        """
        Create a session for an imported agent.

        The session is constructed with the default ``auto=True``, so
        ``setup`` and ``learn`` run during construction. The UUID is then
        reset to an empty string.

        :param gent_path: Currently unused.
            NOTE(review): looks like a typo for ``agent_path``; the name is
            kept so existing keyword callers keep working.
        :param training_config_path: Path to the training config file.
        :param lay_down_config_path: Path to the lay down config file.
        :return: The newly created session.
        """
        session = cls(training_config_path, lay_down_config_path)

        # Reset the UUID
        session._uuid = ""

        return session