2023-05-25 10:52:29 +01:00
|
|
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
2023-06-07 22:57:37 +01:00
|
|
|
import tempfile
|
2023-05-25 14:05:53 +01:00
|
|
|
import time
|
2023-06-07 22:57:37 +01:00
|
|
|
from datetime import datetime
|
2023-05-25 14:05:53 +01:00
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
from primaite.environment.primaite_env import Primaite
|
|
|
|
|
|
|
|
|
|
# Not referenced by any function in this part of the file.
# NOTE(review): presumably the number of node values in the action space
# expected by the tests -- confirm against the Primaite environment.
ACTION_SPACE_NODE_VALUES = 1
|
|
|
|
|
# Not referenced by any function in this part of the file.
# NOTE(review): presumably the number of node-action values in the action
# space expected by the tests -- confirm against the Primaite environment.
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
|
|
|
|
|
|
|
|
|
|
2023-06-07 22:57:37 +01:00
|
|
|
def _get_temp_session_path(session_timestamp: datetime) -> Path:
    """
    Build, create and return the temp directory a test session writes to.

    The directory is located under the system temp directory at
    ``primaite/<YYYY-MM-DD>/<YYYY-MM-DD_HH-MM-SS>``. Missing parent
    directories are created, and an already existing directory is reused.

    :param session_timestamp: The datetime at which the session started.
    :return: The path of the session directory.
    """
    day_folder = f"{session_timestamp:%Y-%m-%d}"
    run_folder = f"{session_timestamp:%Y-%m-%d_%H-%M-%S}"

    target = Path(tempfile.gettempdir()).joinpath("primaite", day_folder, run_folder)
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
|
|
|
|
|
2023-05-25 14:05:53 +01:00
|
|
|
def _get_primaite_env_from_config(
    training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
):
    """
    Create a Primaite instance from the given config paths and return it.

    A temp session directory, named after the current time, is created for
    the instance. When the training config of the created instance selects
    the ``GENERIC`` agent, the instance is run through ``run_generic``
    before it is returned.

    :param training_config_path: The training config path handed to Primaite.
    :param lay_down_config_path: The lay down config path handed to Primaite.
    :return: The created instance of Primaite.
    """
    started_at: datetime = datetime.now()

    # The session directory is created on disk before the environment is built.
    primaite_kwargs = {
        "training_config_path": training_config_path,
        "lay_down_config_path": lay_down_config_path,
        "transaction_list": [],
        "session_path": _get_temp_session_path(started_at),
        "timestamp_str": started_at.strftime("%Y-%m-%d_%H-%M-%S"),
    }
    env = Primaite(**primaite_kwargs)

    # Align the configured step count with the environment's episode steps.
    training_cfg = env.training_config
    training_cfg.num_steps = env.episode_steps

    # TODO: This needs to be refactored to happen outside. Should be part of
    #  a main Session class.
    if env.training_config.agent_identifier == "GENERIC":
        run_generic(env, training_cfg)

    return env
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_generic(env, config_values):
    """
    Run the environment against a generic agent.

    For each of ``config_values.num_episodes`` episodes, up to
    ``config_values.num_steps`` steps are taken. An episode stops early as
    soon as a step reports that it is done; otherwise the loop pauses for
    ``config_values.time_delay`` milliseconds before the next step.

    :param env: The environment to step; ``env.step(action)`` must return an
        ``(obs, reward, done, info)`` 4-tuple.
    :param config_values: Object providing ``num_episodes``, ``num_steps``
        and ``time_delay``.
    """
    # TEMP - no agent is consulted yet; the same fixed action is always sent.
    placeholder_action = 0

    # NOTE: the environment is neither reset between episodes nor closed
    # afterwards here; those calls are currently disabled.
    for _episode in range(config_values.num_episodes):
        for _step in range(config_values.num_steps):
            # Run the simulation step on the live environment.
            _obs, _reward, finished, _info = env.step(placeholder_action)

            if finished:
                break

            # time_delay is in milliseconds, time.sleep expects seconds.
            time.sleep(config_values.time_delay / 1000)
|