Finished writing custom agent example.

This commit is contained in:
Marek Wolan
2023-07-10 14:56:06 +01:00
parent 43a4f93626
commit 30bcdba429

View File

@@ -2,14 +2,21 @@
=============
**Integrating a user defined blue agent**
Integrating a user defined blue agent
*************************************
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, ``save()``, and ``export()`` methods.
.. note::
If you are planning to implement custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a python package from wheel, custom agents are not supported.
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
Below is a barebones example of a custom agent implementation:
.. code:: python
# src/primaite/agents/my_custom_agent.py
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
@@ -63,72 +70,69 @@ Below is a barebones example of a custom agent implementation:
@classmethod
def load(cls, path):
...
#
# Create a CustomAgent object which loads model weights from file.
def save(self):
...
# Call your agent's function that saves it to a file
def export(self):
...
# Call your agent's function that exports it to a transportable file format.
You will also need to modify :py:class:`primaite.primaite_session.PrimaiteSession<PrimaiteSession>` and :py:mod:`primaite.common.enums` to capture your new agent identifiers.
You will also need to modify :py:class:`primaite.primaite_session.PrimaiteSession<PrimaiteSession>` class to capture your new agent identifier.
.. code-block:: python
:emphasize-lines: 17, 18
# src/primaite/common/enums.py
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"The Hardcoded agents"
DO_NOTHING = 4
"The DoNothing agents"
RANDOM = 5
"The RandomAgent"
DUMMY = 6
"The DummyAgent"
CUSTOM_AGENT = 7
"Your custom agent"
.. code-block:: python
:emphasize-lines: 3, 11, 12
# src/primaite_session.py
The selection of which agent type to use is made via the training config file. In order to train a user generated agent,
the run_generic function should be selected, and should be modified (typically) to be:
from primaite.agents.my_custom_agent import CustomAgent
.. code:: python
# ...
agent = MyAgent(environment, num_steps)
for episode in range(0, num_episodes):
agent.learn()
env.close()
save_agent(agent)
def setup(self):
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
if self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT:
self._agent_session = CustomAgent(self._training_config_path, self._lay_down_config_path)
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
Where:
Finally, specify your agent in your training config.
* *MyAgent* is the user created agent
* *environment* is the :class:`~primaite.environment.primaite_env.Primaite` environment
* *num_episodes* is the number of episodes in the session, as defined in the training config file
* *num_steps* is the number of steps in an episode, as defined in the training config file
* the *.learn()* function should be defined in the user created agent
* the *env.close()* function is defined within PrimAITE
* the *save_agent()* assumes that a *save()* function has been defined in the user created agent. If not, this line can
be ommitted (although it is encouraged, since it will allow the agent to be saved and ported)
.. code-block:: yaml
The code below provides a suggested format for the learn() function within the user created agent.
It's important to include the *self.environment.reset()* call within the episode loop in order that the
environment is reset between episodes. Note that the example below should not be considered exhaustive.
# ~/primaite/config/path/to/your/config_main.yaml
.. code:: python
# Training Config File
def learn(self) :
agent_framework: CUSTOM
agent_identifier: CUSTOM_AGENT
random_red_agent: False
# ...
# pre-reqs
# reset the environment
self.environment.reset()
done = False
for step in range(max_steps):
# calculate the action
action = ...
# execute the environment step
new_state, reward, done, info = self.environment.step(action)
# algorithm updates
...
# update to our new state
state = new_state
# if done, finish episode
if done == True:
break
Now you can `Run a PrimAITE Session<run a primaite session>` with your custom agent by passing in the custom ``config_main``.