From 74821920465a61550728475b1fba3bb6ca32ae55 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Jun 2023 12:01:01 +0100 Subject: [PATCH] #917 - Synced with dev and added better logging --- src/primaite/__init__.py | 47 +++++++++- src/primaite/agents/agent.py | 23 ++--- src/primaite/agents/rllib.py | 3 +- src/primaite/agents/sb3.py | 10 ++- src/primaite/agents/utils.py | 2 +- src/primaite/cli.py | 19 ++++- src/primaite/common/enums.py | 11 +-- .../training/training_config_main.yaml | 19 ++--- src/primaite/config/training_config.py | 19 ++++- src/primaite/environment/primaite_env.py | 85 ++++++++++--------- src/primaite/primaite_session.py | 12 +-- .../setup/_package_data/primaite_config.yaml | 7 +- 12 files changed, 170 insertions(+), 87 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 420420f4..24815727 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -2,9 +2,12 @@ import logging import logging.config import sys -from logging import Logger, StreamHandler +from bisect import bisect +from logging import Formatter, LogRecord, StreamHandler +from logging import Logger from logging.handlers import RotatingFileHandler from pathlib import Path +from typing import Dict from typing import Final import pkg_resources @@ -68,6 +71,33 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. # region Setup Logging +class _LevelFormatter(Formatter): + """ + A custom level-specific formatter. + + Credit to: https://stackoverflow.com/a/68154386 + """ + + def __init__(self, formats: Dict[int, str], **kwargs): + super().__init__() + + if "fmt" in kwargs: + raise ValueError( + "Format string must be passed to level-surrogate formatters, " + "not this one" + ) + + self.formats = sorted( + (level, Formatter(fmt, **kwargs)) for level, fmt in formats.items() + ) + + def format(self, record: LogRecord) -> str: + """Overrides ``Formatter.format``.""" + idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1) + level, formatter = self.formats[idx] + return formatter.format(record) + + def _log_dir() -> Path: if sys.platform == "win32": dir_path = _PLATFORM_DIRS.user_data_path / "logs" @@ -76,6 +106,16 @@ def _log_dir() -> Path: return dir_path +_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( + { + logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"], + logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"], + logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"], + logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"], + logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"] + } +) + LOG_DIR: Final[Path] = _log_dir() """The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS.""" @@ -85,6 +125,7 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log" """The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS.""" _STREAM_HANDLER: Final[StreamHandler] = StreamHandler() + _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( filename=LOG_PATH, maxBytes=10485760, # 10MB @@ -95,8 +136,8 @@ _STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) _FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) _LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"] -_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) -_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) +_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) +_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) _LOGGER = logging.getLogger(__name__) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 812072ba..5f4dac8f 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -3,11 +3,11 @@ import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Optional, Final, Dict, Union, List +from typing import Optional, Final, Dict, Union from uuid import uuid4 +import primaite from primaite import getLogger, SESSIONS_DIR -from primaite.common.enums import OutputVerboseLevel from primaite.config import lay_down_config from primaite.config import training_config from primaite.config.training_config import TrainingConfig @@ -141,14 +141,13 @@ class AgentSessionABC(ABC): @abstractmethod def _setup(self): - if self.output_verbose_level >= OutputVerboseLevel.INFO: - _LOGGER.info( - "Welcome to the Primary-level AI Training Environment " - "(PrimAITE)" - ) - _LOGGER.debug( - f"The output directory for this agent is: {self.session_path}" - ) + _LOGGER.info( + "Welcome to the Primary-level AI Training Environment " + f"(PrimAITE) (version: {primaite.__version__})" + ) + _LOGGER.info( + f"The output directory for this session is: {self.session_path}" + ) self._write_session_metadata_file() self._can_learn = True self._can_evaluate = False @@ -165,6 +164,7 @@ class AgentSessionABC(ABC): **kwargs ): if self._can_learn: + _LOGGER.info("Finished learning") _LOGGER.debug("Writing transactions") self._update_session_metadata_file() self._can_evaluate = True @@ -176,7 +176,7 @@ class AgentSessionABC(ABC): episodes: Optional[int] = None, **kwargs ): - pass + _LOGGER.info("Finished evaluation") @abstractmethod def _get_latest_checkpoint(self): @@ -260,6 +260,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): # Introduce a delay between steps time.sleep(self._training_config.time_delay / 1000) self._env.close() + super().evaluate() @classmethod def load(cls): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index b4b0ec56..710225d7 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -152,7 +152,8 @@ class RLlibAgent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - + _LOGGER.info(f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps...") for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 073eb2fe..4d2ded6b 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -69,8 +69,9 @@ class SB3Agent(AgentSessionABC): (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes) ): - self._agent.save( - self.checkpoints_path / f"sb3ppo_{episode_count}.zip") + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -86,6 +87,8 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes + _LOGGER.info(f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps...") for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() @@ -106,6 +109,9 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes + _LOGGER.info(f"Beginning evaluation for {episodes} episodes @" + f" {time_steps} time steps...") + for episode in range(episodes): obs = self._env.reset() diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index acc71590..a4eadc3b 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -6,8 +6,8 @@ from primaite.common.enums import ( NodeHardwareAction, NodeSoftwareAction, SoftwareState, + NodePOLType ) -from primaite.common.enums import NodePOLType def transform_action_node_readable(action): diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 319d643c..aa88a391 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -160,13 +160,26 @@ def setup(overwrite_existing: bool = True): @app.command() -def session(tc: str, ldc: str): +def session(tc: Optional[str] = None, ldc: Optional[str] = None): """ Run a PrimAITE session. - :param tc: The training config filepath. - :param ldc: The lay down config file path. + tc: The training config filepath. Optional. If no value is passed then + example default training config is used from: + ~/primaite/config/example_config/training/training_config_main.yaml. + + ldc: The lay down config file path. Optional. If no value is passed then + example default lay down config is used from: + ~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml. """ from primaite.main import run + from primaite.config.training_config import main_training_config_path + from primaite.config.lay_down_config import data_manipulation_config_path + + if not tc: + tc = main_training_config_path() + + if not ldc: + ldc = data_manipulation_config_path() run(training_config_path=tc, lay_down_config_path=ldc) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 191cb782..6a93e1b5 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -83,9 +83,12 @@ class Protocol(Enum): class SessionType(Enum): """The type of PrimAITE Session to be run.""" - TRAINING = 1 - EVALUATION = 2 - BOTH = 3 + TRAIN = 1 + "Train an agent" + EVAL = 2 + "Evaluate an agent" + TRAIN_EVAL = 3 + "Train then evaluate an agent" class VerboseLevel(IntEnum): @@ -141,7 +144,6 @@ class HardCodedAgentView(Enum): class ActionType(Enum): """Action type enumeration.""" - NODE = 0 ACL = 1 ANY = 2 @@ -149,7 +151,6 @@ class ActionType(Enum): class ObservationType(Enum): """Observation type enumeration.""" - BOX = 0 MULTIDISCRETE = 1 diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 0f99a501..9cbcb702 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -38,26 +38,23 @@ hard_coded_agent_view: FULL action_type: ANY # Number of episodes to run per session -num_episodes: 100 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 # Sets how often the agent will save a checkpoint (every n time episodes). # Set to 0 if no checkpoints are required. Default is 10 -checkpoint_every_n_episodes: 100 +checkpoint_every_n_episodes: 10 # Time delay between steps (for generic agents) -time_delay: 3 +time_delay: 5 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING - -# Determine whether to load an agent from file -load_agent: False - -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN # Environment config values # The high value for the observation space diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 020d5b03..72b5523a 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -69,7 +69,7 @@ class TrainingConfig: "The delay between steps (ms). Applies to generic agents only" # file - session_type: SessionType = SessionType.TRAINING + session_type: SessionType = SessionType.TRAIN "The type of PrimAITE session to run" load_agent: str = False @@ -171,6 +171,7 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system" + @classmethod def from_dict( cls, @@ -183,7 +184,7 @@ class TrainingConfig: "action_type": ActionType, "session_type": SessionType, "output_verbose_level": OutputVerboseLevel, - "hard_coded_agent_view": HardCodedAgentView + "hard_coded_agent_view": HardCodedAgentView, } for field, enum_class in field_enum_map.items(): @@ -211,6 +212,20 @@ class TrainingConfig: return data + def __str__(self) -> str: + tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, " + if self.agent_framework is AgentFramework.RLLIB: + tc += f"deep_learning_framework=" \ + f"{self.deep_learning_framework.name}, " + tc += f"agent_identifier={self.agent_identifier.name}, " + if self.agent_identifier is AgentIdentifier.HARDCODED: + tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, " + tc += f"action_type={self.action_type.name}, " + tc += f"observation_space={self.observation_space}, " + tc += f"num_episodes={self.num_episodes}, " + tc += f"num_steps={self.num_steps})" + return tc + def load( file_path: Union[str, Path], diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 44f576ce..5319d0f1 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -15,7 +15,8 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, \ + is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -36,13 +37,15 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import \ + NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, \ + apply_red_agent_node_pol from primaite.transactions.transaction import Transaction from primaite.transactions.transactions_to_file import \ write_transaction_to_file @@ -61,12 +64,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -86,6 +89,7 @@ class Primaite(Env): self.training_config: TrainingConfig = training_config.load( training_config_path ) + _LOGGER.info(f"Using: {str(self.training_config)}") # Number of steps in an episode self.episode_steps = self.training_config.num_steps @@ -207,16 +211,14 @@ class Primaite(Env): plt.savefig(file_path, format="PNG") plt.clf() except Exception: - _LOGGER.error("Could not save network diagram") - _LOGGER.error("Exception occured", exc_info=True) - print("Could not save network diagram") + _LOGGER.error("Could not save network diagram", exc_info=True) # Initiate observation space self.observation_space, self.env_obs = self.init_observations() # Define Action Space - depends on action space type (Node or ACL) if self.training_config.action_type == ActionType.NODE: - _LOGGER.info("Action space type NODE selected") + _LOGGER.debug("Action space type NODE selected") # Terms (for node action space): # [0, num nodes] - node ID (0 = nothing, node ID) # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa @@ -225,7 +227,7 @@ class Primaite(Env): self.action_dict = self.create_node_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) elif self.training_config.action_type == ActionType.ACL: - _LOGGER.info("Action space type ACL selected") + _LOGGER.debug("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) # [0, 1] - Permission (0 = DENY, 1 = ALLOW) @@ -236,11 +238,11 @@ class Primaite(Env): self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) elif self.training_config.action_type == ActionType.ANY: - _LOGGER.info("Action space type ANY selected - Node + ACL") + _LOGGER.debug("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info( + _LOGGER.error( f"Invalid action type selected: {self.training_config.action_type}" ) # Set up a csv to store the results of the training @@ -301,17 +303,14 @@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ - # Introduce a delay between steps - time.sleep(self.training_config.time_delay / 1000) if self.step_count == 0: - print(f"Episode: {str(self.episode_count)}") + _LOGGER.info(f"Episode: {str(self.episode_count)}") # TEMP done = False self.step_count += 1 self.total_step_count += 1 - # print("Episode step: " + str(self.step_count)) # Need to clear traffic on all links first for link_key, link_value in self.links.items(): @@ -322,7 +321,8 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction( - datetime.now(), self.agent_identifier, self.episode_count, self.step_count + datetime.now(), self.agent_identifier, self.episode_count, + self.step_count ) # Load the initial observation space into the transaction transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) @@ -352,7 +352,8 @@ class Primaite(Env): self.nodes_post_pol = copy.deepcopy(self.nodes) self.links_post_pol = copy.deepcopy(self.links) # Reference - apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL + apply_node_pol(self.nodes_reference, self.node_pol, + self.step_count) # Node PoL apply_iers( self.network_reference, self.nodes_reference, @@ -389,7 +390,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + _LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -397,7 +398,7 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - print(f" Average Reward: {str(self.average_reward)}") + _LOGGER.info(f" Average Reward: {str(self.average_reward)}") # Load the reward into the transaction transaction.set_reward(reward) @@ -428,6 +429,7 @@ class Primaite(Env): self.timestamp_str ) self.csv_file.close() + def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() @@ -435,9 +437,9 @@ class Primaite(Env): def output_link_status(self): """Output the link status of all links to the console.""" for link_key, link_value in self.links.items(): - print("Link ID: " + link_value.get_id()) + _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - print( + _LOGGER.debug( " Protocol: " + protocol.get_name().name + ", Load: " @@ -457,11 +459,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -529,7 +531,8 @@ class Primaite(Env): elif property_action == 1: # Patch (valid action if it's good or compromised) node.set_service_state( - self.services_list[service_index], SoftwareState.PATCHING + self.services_list[service_index], + SoftwareState.PATCHING ) else: # Node is not of Service Type @@ -589,7 +592,8 @@ class Primaite(Env): acl_rule_source = "ANY" else: node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): + if isinstance(node, ServiceNode) or isinstance(node, + ActiveNode): acl_rule_source = node.ip_address else: return @@ -598,7 +602,8 @@ class Primaite(Env): acl_rule_destination = "ANY" else: node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): + if isinstance(node, ServiceNode) or isinstance(node, + ActiveNode): acl_rule_destination = node.ip_address else: return @@ -683,7 +688,8 @@ class Primaite(Env): :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] """ - self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + self.obs_handler = ObservationsHandler.from_config(self, + self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation @@ -727,8 +733,7 @@ class Primaite(Env): _LOGGER.error(f"Invalid item_type: {item_type}") pass - _LOGGER.info("Environment configuration loaded") - print("Environment configuration loaded") + _LOGGER.debug("Environment configuration loaded") def create_node(self, item): """ @@ -791,7 +796,8 @@ class Primaite(Env): service_protocol = service["name"] service_port = service["port"] service_state = SoftwareState[service["state"]] - node.add_service(Service(service_protocol, service_port, service_state)) + node.add_service( + Service(service_protocol, service_port, service_state)) else: # Bad formatting pass @@ -844,7 +850,8 @@ class Primaite(Env): dest_node_ref: Node = self.nodes_reference[link_destination] # Add link to network (reference) - self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name) + self.network_reference.add_edge(source_node_ref, dest_node_ref, + id=link_name) # Add link to link dictionary (reference) self.links_reference[link_name] = Link( @@ -1119,7 +1126,8 @@ class Primaite(Env): # All nodes have these parameters node_id = item["node_id"] node_class = item["node_class"] - node_hardware_state: HardwareState = HardwareState[item["hardware_state"]] + node_hardware_state: HardwareState = HardwareState[ + item["hardware_state"]] node: NodeUnion = self.nodes[node_id] node_ref = self.nodes_reference[node_id] @@ -1185,7 +1193,8 @@ class Primaite(Env): # Use MAX to ensure we get them all for node_action in range(4): for service_state in range(self.num_services): - action = [node, node_property, node_action, service_state] + action = [node, node_property, node_action, + service_state] # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): actions[action_key] = action diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 70a18a4b..cd959be0 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -47,8 +47,7 @@ class PrimaiteSession: def __init__( self, training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - auto: bool = True + lay_down_config_path: Union[str, Path] ): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) @@ -64,13 +63,8 @@ class PrimaiteSession: self._lay_down_config_path ) - self._auto: bool = auto self._agent_session: AgentSessionABC = None # noqa - if self._auto: - self.setup() - self.learn() - def setup(self): if self._training_config.agent_framework == AgentFramework.CUSTOM: if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: @@ -157,7 +151,7 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - if not self._training_config.session_type == SessionType.EVALUATION: + if not self._training_config.session_type == SessionType.EVAL: self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( @@ -166,5 +160,5 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - if not self._training_config.session_type == SessionType.TRAINING: + if not self._training_config.session_type == SessionType.TRAIN: self._agent_session.evaluate(time_steps, episodes, **kwargs) diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 690544fb..5d469ffe 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -2,4 +2,9 @@ # Logging log_level: INFO -logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' +logger_format: + DEBUG: '%(asctime)s: %(message)s' + INFO: '%(asctime)s: %(message)s' + WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'