Finished writing custom agent example.
@@ -2,14 +2,21 @@
|
|||||||
=============
|
=============
|
||||||
|
|
||||||
|
|
||||||
**Integrating a user defined blue agent**
|
Integrating a user defined blue agent
|
||||||
|
*************************************
|
||||||
|
|
||||||
PrimAITE integrates with Ray RLlib and Stable Baselines3 agents. All agents interface with PrimAITE through an :py:class:`Agent Session <primaite.agents.agent.AgentSessionABC>`, which handles input/output of agent save files and captures and plots performance metrics during training. To integrate a custom blue agent, it is recommended to subclass :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint()``, ``load()``, ``save()``, and ``export()`` methods.
|
.. note::
|
||||||
|
|
||||||
|
If you plan to integrate custom RL agents into PrimAITE, you must work from a clone of the project repository. Custom agents are not supported when PrimAITE is installed as a Python package from a wheel.
|
||||||
|
|
||||||
|
PrimAITE integrates with Ray RLlib and Stable Baselines3 agents. All agents interface with PrimAITE through an :py:class:`Agent Session <primaite.agents.agent.AgentSessionABC>`, which handles input/output of agent save files and captures and plots performance metrics during training and evaluation. To integrate a custom blue agent, it is recommended to subclass :py:class:`primaite.agents.agent.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint()``, ``load()``, and ``save()`` methods.
|
||||||
|
|
||||||
Below is a barebones example of a custom agent implementation:
|
Below is a barebones example of a custom agent implementation:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
|
# src/primaite/agents/my_custom_agent.py
|
||||||
|
|
||||||
from primaite.agents.agent import AgentSessionABC
|
from primaite.agents.agent import AgentSessionABC
|
||||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||||
|
|
||||||
@@ -63,72 +70,69 @@ Below is a barebones example of a custom agent implementation:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path):
|
def load(cls, path):
|
||||||
...
|
...
|
||||||
#
|
# Create a CustomAgent object which loads model weights from file.
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
...
|
...
|
||||||
# Call your agent's function that saves it to a file
|
# Call your agent's function that saves it to a file
|
||||||
|
|
||||||
def export(self):
|
|
||||||
...
|
|
||||||
# Call your agent's function that exports it to a transportable file format.
|
|
||||||
|
|
||||||
|
You will also need to modify the :py:class:`PrimaiteSession <primaite.primaite_session.PrimaiteSession>` class and the :py:mod:`primaite.common.enums` module so that they recognise your new agent identifier.
|
||||||
|
|
||||||
You will also need to modify the :py:class:`PrimaiteSession <primaite.primaite_session.PrimaiteSession>` class so that it recognises your new agent identifier.
|
.. code-block:: python
|
||||||
|
:emphasize-lines: 17, 18
|
||||||
|
|
||||||
|
# src/primaite/common/enums.py
|
||||||
|
|
||||||
|
class AgentIdentifier(Enum):
|
||||||
|
"""The Red Agent algo/class."""
|
||||||
|
A2C = 1
|
||||||
|
"Advantage Actor Critic"
|
||||||
|
PPO = 2
|
||||||
|
"Proximal Policy Optimization"
|
||||||
|
HARDCODED = 3
|
||||||
|
"The Hardcoded agents"
|
||||||
|
DO_NOTHING = 4
|
||||||
|
"The DoNothing agents"
|
||||||
|
RANDOM = 5
|
||||||
|
"The RandomAgent"
|
||||||
|
DUMMY = 6
|
||||||
|
"The DummyAgent"
|
||||||
|
CUSTOM_AGENT = 7
|
||||||
|
"Your custom agent"
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
:emphasize-lines: 3, 11, 12
|
||||||
|
|
||||||
|
# src/primaite_session.py
|
||||||
|
|
||||||
The selection of which agent type to use is made via the training config file. In order to train a user generated agent,
|
from primaite.agents.my_custom_agent import CustomAgent
|
||||||
the run_generic function should be selected, and should be modified (typically) to be:
|
|
||||||
|
|
||||||
.. code:: python
|
# ...
|
||||||
|
|
||||||
agent = MyAgent(environment, num_steps)
|
def setup(self):
|
||||||
for episode in range(0, num_episodes):
|
"""Performs the session setup."""
|
||||||
agent.learn()
|
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||||
env.close()
|
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||||
save_agent(agent)
|
if self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT:
|
||||||
|
self._agent_session = CustomAgent(self._training_config_path, self._lay_down_config_path)
|
||||||
|
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||||
|
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier = {AgentIdentifier.HARDCODED}")
|
||||||
|
if self._training_config.action_type == ActionType.NODE:
|
||||||
|
# Deterministic Hardcoded Agent with Node Action Space
|
||||||
|
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||||
|
|
||||||
Where:
|
Finally, specify your agent in your training config.
|
||||||
|
|
||||||
* *MyAgent* is the user created agent
|
.. code-block:: yaml
|
||||||
* *environment* is the :class:`~primaite.environment.primaite_env.Primaite` environment
|
|
||||||
* *num_episodes* is the number of episodes in the session, as defined in the training config file
|
|
||||||
* *num_steps* is the number of steps in an episode, as defined in the training config file
|
|
||||||
* the *.learn()* function should be defined in the user created agent
|
|
||||||
* the *env.close()* function is defined within PrimAITE
|
|
||||||
* the *save_agent()* assumes that a *save()* function has been defined in the user created agent. If not, this line can
|
|
||||||
be omitted (although it is encouraged, since it will allow the agent to be saved and ported)
|
|
||||||
|
|
||||||
The code below provides a suggested format for the learn() function within the user created agent.
|
# ~/primaite/config/path/to/your/config_main.yaml
|
||||||
It's important to include the *self.environment.reset()* call within the episode loop in order that the
|
|
||||||
environment is reset between episodes. Note that the example below should not be considered exhaustive.
|
|
||||||
|
|
||||||
.. code:: python
|
# Training Config File
|
||||||
|
|
||||||
def learn(self):
|
agent_framework: CUSTOM
|
||||||
|
agent_identifier: CUSTOM_AGENT
|
||||||
|
random_red_agent: False
|
||||||
|
# ...
|
||||||
|
|
||||||
# pre-reqs
|
Now you can :ref:`Run a PrimAITE Session <run a primaite session>` with your custom agent by passing in the custom ``config_main``.
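
The sketch below illustrates, under stated assumptions, how the ``agent_framework`` and ``agent_identifier`` strings from the training config could be resolved to enum members and used to pick the custom agent class. The enums are trimmed stand-ins for those in ``primaite.common.enums`` (the ``CUSTOM`` value is assumed), and ``select_agent_class`` and the placeholder ``CustomAgent`` are hypothetical names used only for illustration; this is not PrimAITE's actual dispatch code.

```python
from enum import Enum


class AgentFramework(Enum):
    """Trimmed stand-in for primaite.common.enums.AgentFramework."""

    CUSTOM = 0  # assumed value, for illustration only


class AgentIdentifier(Enum):
    """Trimmed stand-in for primaite.common.enums.AgentIdentifier."""

    HARDCODED = 3
    CUSTOM_AGENT = 7


class CustomAgent:
    """Placeholder for the AgentSessionABC subclass described above."""


def select_agent_class(config: dict) -> type:
    """Hypothetical helper: map config strings to an agent class."""
    # Enum lookup by name raises KeyError for unknown identifiers.
    framework = AgentFramework[config["agent_framework"]]
    identifier = AgentIdentifier[config["agent_identifier"]]

    if framework is AgentFramework.CUSTOM and identifier is AgentIdentifier.CUSTOM_AGENT:
        return CustomAgent
    raise ValueError(f"No agent registered for {framework} / {identifier}")


if __name__ == "__main__":
    cfg = {"agent_framework": "CUSTOM", "agent_identifier": "CUSTOM_AGENT"}
    print(select_agent_class(cfg).__name__)
```

Looking members up by name means a misspelled identifier in the config fails immediately with a ``KeyError`` rather than silently falling through to another agent.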
|
||||||
|
|
||||||
# reset the environment
|
|
||||||
self.environment.reset()
|
|
||||||
done = False
|
|
||||||
|
|
||||||
for step in range(max_steps):
|
|
||||||
# calculate the action
|
|
||||||
action = ...
|
|
||||||
|
|
||||||
# execute the environment step
|
|
||||||
new_state, reward, done, info = self.environment.step(action)
|
|
||||||
|
|
||||||
# algorithm updates
|
|
||||||
...
|
|
||||||
|
|
||||||
# update to our new state
|
|
||||||
state = new_state
|
|
||||||
|
|
||||||
# if done, finish episode
|
|
||||||
if done:
|
|
||||||
break
|
|
||||||
|
|||||||
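
For reference, the episode loop from the earlier ``learn()`` outline can be exercised end to end with a toy environment. This is a minimal sketch, not PrimAITE code: ``ToyEnv`` and ``SketchAgent`` are hypothetical stand-ins, the environment is assumed to follow the gym-style ``reset()`` / four-tuple ``step()`` interface shown in that outline, and the random action choice merely takes the place of a real policy.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for the PrimAITE environment (gym-style API)."""

    def __init__(self, episode_length: int = 5):
        self.episode_length = episode_length
        self.step_count = 0

    def reset(self) -> int:
        # Start a fresh episode and return the initial state.
        self.step_count = 0
        return 0

    def step(self, action: int):
        # Advance one step; every step yields a reward of 1.0.
        self.step_count += 1
        done = self.step_count >= self.episode_length
        return self.step_count, 1.0, done, {}


class SketchAgent:
    """Hypothetical agent whose learn() mirrors the outline above."""

    def __init__(self, environment, num_episodes: int, max_steps: int, seed: int = 0):
        self.environment = environment
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.rng = random.Random(seed)

    def learn(self) -> list:
        episode_rewards = []
        for _ in range(self.num_episodes):
            # Reset inside the episode loop so each episode starts clean.
            state = self.environment.reset()
            total_reward = 0.0
            for _ in range(self.max_steps):
                action = self.rng.choice([0, 1])  # placeholder policy
                new_state, reward, done, info = self.environment.step(action)
                total_reward += reward
                # Algorithm updates would go here.
                state = new_state
                if done:
                    break
            episode_rewards.append(total_reward)
        return episode_rewards


if __name__ == "__main__":
    agent = SketchAgent(ToyEnv(episode_length=5), num_episodes=3, max_steps=10)
    print(agent.learn())
```

Without the ``reset()`` call inside the episode loop, the toy environment's step counter would carry over between episodes, which is the failure mode the outline warns about.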