from __future__ import annotations import json from datetime import datetime from pathlib import Path from typing import Final, Optional, Union from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) def _get_session_path(session_timestamp: datetime) -> Path: """ Get the directory path the session will output to. This is set in the format of: ~/primaite/sessions//_. :param session_timestamp: This is the datetime that the session started. :return: The session directory path. """ date_dir = session_timestamp.strftime("%Y-%m-%d") session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") session_path = SESSIONS_DIR / date_dir / session_dir session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created PrimAITE Session path: {session_path}") return session_path class PrimaiteSession: def __init__( self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], auto: bool = True ): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path self._auto: Final[bool] = auto self._uuid: str = str(uuid4()) self._session_timestamp: Final[datetime] = datetime.now() self._session_path: Final[Path] = _get_session_path( self._session_timestamp ) self._timestamp_str: Final[str] = self._session_timestamp.strftime( "%Y-%m-%d_%H-%M-%S") self._metadata_path = self._session_path / "session_metadata.json" self._env = None self._training_config = None self._can_learn: bool = False _LOGGER.debug("") if self._auto: self.setup() self.learn() @property def uuid(self): """The session UUID.""" return self._uuid def _setup_primaite_env(self, transaction_list: Optional[list] = None): if not transaction_list: transaction_list = [] self._env: Primaite = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, transaction_list=transaction_list, session_path=self._session_path, timestamp_str=self._timestamp_str ) self._training_config: TrainingConfig = self._env.training_config def _write_session_metadata_file(self): """ Write the ``session_metadata.json`` file. Creates a ``session_metadata.json`` in the ``session_dir`` directory and adds the following key/value pairs: - uuid: The UUID assigned to the session upon instantiation. - start_datetime: The date & time the session started in iso format. - end_datetime: NULL. - total_episodes: NULL. - total_time_steps: NULL. - env: - training_config: - All training config items - lay_down_config: - All lay down config items """ metadata_dict = { "uuid": self._uuid, "start_datetime": self._session_timestamp.isoformat(), "end_datetime": None, "total_episodes": None, "total_time_steps": None, "env": { "training_config": self._env.training_config.to_dict( json_serializable=True ), "lay_down_config": self._env.lay_down_config, }, } _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}") with open(self._metadata_path, "w") as file: json.dump(metadata_dict, file) def _update_session_metadata_file(self): """ Update the ``session_metadata.json`` file. Updates the `session_metadata.json`` in the ``session_dir`` directory with the following key/value pairs: - end_datetime: NULL. - total_episodes: NULL. - total_time_steps: NULL. """ with open(self._metadata_path, "r") as file: metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() metadata_dict["total_episodes"] = self._env.episode_count metadata_dict["total_time_steps"] = self._env.total_step_count _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}") with open(self._metadata_path, "w") as file: json.dump(metadata_dict, file) def setup(self): self._setup_primaite_env() self._can_learn = True pass def learn( self, time_steps: Optional[int], episodes: Optional[int], iterations: Optional[int], **kwargs ): if self._can_learn: # Run environment against an agent if self._training_config.agent_identifier == "GENERIC": run_generic(env=env, config_values=config_values) elif self._training_config == "STABLE_BASELINES3_PPO": run_stable_baselines3_ppo( env=env, config_values=config_values, session_path=session_dir, timestamp_str=timestamp_str, ) elif self._training_config == "STABLE_BASELINES3_A2C": run_stable_baselines3_a2c( env=env, config_values=config_values, session_path=session_dir, timestamp_str=timestamp_str, ) print("Session finished") _LOGGER.debug("Session finished") print("Saving transaction logs...") write_transaction_to_file( transaction_list=transaction_list, session_path=session_dir, timestamp_str=timestamp_str, ) print("Updating Session Metadata file...") _update_session_metadata_file(session_dir=session_dir, env=env) print("Finished") _LOGGER.debug("Finished") def evaluate( self, time_steps: Optional[int], episodes: Optional[int], **kwargs ): pass def export(self): pass @classmethod def import_agent( cls, gent_path: str, training_config_path: str, lay_down_config_path: str ) -> PrimaiteSession: session = PrimaiteSession(training_config_path, lay_down_config_path) # Reset the UUID session._uuid = "" return session