From 31703c54e25a05e8aa02a18a223f5e73496c1a1d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 10 Jul 2023 14:56:06 +0100 Subject: [PATCH] Finished writing custom agent example. --- docs/source/custom_agent.rst | 106 ++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 51 deletions(-) diff --git a/docs/source/custom_agent.rst b/docs/source/custom_agent.rst index 74b6a607..45d1c5a4 100644 --- a/docs/source/custom_agent.rst +++ b/docs/source/custom_agent.rst @@ -2,14 +2,21 @@ ============= -**Integrating a user defined blue agent** +Integrating a user defined blue agent +************************************* -PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, ``save()``, and ``export()`` methods. +.. note:: + + If you are planning to integrate custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a Python package from a wheel, custom agents are not supported. + +PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. 
If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint()``, ``load()``, and ``save()`` methods. Below is a barebones example of a custom agent implementation: .. code:: python + # src/primaite/agents/my_custom_agent.py + from primaite.agents.agent import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier @@ -63,72 +70,69 @@ Below is a barebones example of a custom agent implementation: @classmethod def load(cls, path): ... - # + # Create a CustomAgent object which loads model weights from file. def save(self): ... # Call your agent's function that saves it to a file - def export(self): - ... - # Call your agent's function that exports it to a transportable file format. +You will also need to modify :py:class:`primaite.primaite_session.PrimaiteSession` and :py:mod:`primaite.common.enums` to capture your new agent identifiers. -You will also need to modify :py:class:`primaite.primaite_session.PrimaiteSession` class to capture your new agent identifier. +.. code-block:: python + :emphasize-lines: 17, 18 + # src/primaite/common/enums.py + class AgentIdentifier(Enum): + """The Red Agent algo/class.""" + A2C = 1 + "Advantage Actor Critic" + PPO = 2 + "Proximal Policy Optimization" + HARDCODED = 3 + "The Hardcoded agents" + DO_NOTHING = 4 + "The DoNothing agents" + RANDOM = 5 + "The RandomAgent" + DUMMY = 6 + "The DummyAgent" + CUSTOM_AGENT = 7 + "Your custom agent" +.. code-block:: python + :emphasize-lines: 3, 11, 12 + # src/primaite/primaite_session.py -The selection of which agent type to use is made via the training config file. In order to train a user generated agent, -the run_generic function should be selected, and should be modified (typically) to be: + from primaite.agents.my_custom_agent import CustomAgent -.. code:: python + # ... 
- agent = MyAgent(environment, num_steps) - for episode in range(0, num_episodes): - agent.learn() - env.close() - save_agent(agent) + def setup(self): + """Performs the session setup.""" + if self._training_config.agent_framework == AgentFramework.CUSTOM: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}") + if self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT: + self._agent_session = CustomAgent(self._training_config_path, self._lay_down_config_path) + if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") + if self._training_config.action_type == ActionType.NODE: + # Deterministic Hardcoded Agent with Node Action Space + self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) -Where: +Finally, specify your agent in your training config. -* *MyAgent* is the user created agent -* *environment* is the :class:`~primaite.environment.primaite_env.Primaite` environment -* *num_episodes* is the number of episodes in the session, as defined in the training config file -* *num_steps* is the number of steps in an episode, as defined in the training config file -* the *.learn()* function should be defined in the user created agent -* the *env.close()* function is defined within PrimAITE -* the *save_agent()* assumes that a *save()* function has been defined in the user created agent. If not, this line can - be ommitted (although it is encouraged, since it will allow the agent to be saved and ported) +.. code-block:: yaml -The code below provides a suggested format for the learn() function within the user created agent. -It's important to include the *self.environment.reset()* call within the episode loop in order that the -environment is reset between episodes. Note that the example below should not be considered exhaustive. 
+ # ~/primaite/config/path/to/your/config_main.yaml -.. code:: python + # Training Config File - def learn(self) : + agent_framework: CUSTOM + agent_identifier: CUSTOM_AGENT + random_red_agent: False + # ... - # pre-reqs - - # reset the environment - self.environment.reset() - done = False - - for step in range(max_steps): - # calculate the action - action = ... - - # execute the environment step - new_state, reward, done, info = self.environment.step(action) - - # algorithm updates - ... - - # update to our new state - state = new_state - - # if done, finish episode - if done == True: - break +Now you can `Run a PrimAITE Session` with your custom agent by passing in the custom ``config_main``.