Begin implementing training loop in session

This commit is contained in:
Marek Wolan
2023-11-13 16:35:35 +00:00
parent 707f2b59af
commit 08e88e52b0
2 changed files with 25 additions and 3 deletions

View File

@@ -52,6 +52,7 @@ class SB3Policy(PolicyABC):
num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session
deterministic = True # TODO: populate values once I figure out how to get them from the config / session
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for episode in range(num_episodes):
obs = self.session.env.reset()
for step in range(time_steps):

View File

@@ -9,6 +9,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -59,31 +60,51 @@ class PrimaiteSession:
def __init__(self):
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agent: AbstractAgent
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
# self.rl_agent: AbstractAgent
# """The agent from the list which communicates with GATE to perform reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
raise NotImplementedError
# n_learn_steps = self.training_options.n_learn_steps
n_learn_episodes = self.training_options.n_learn_episodes
# n_eval_steps = self.training_options.n_eval_steps
n_eval_episodes = self.training_options.n_eval_episodes
if n_learn_episodes > 0:
self.policy.learn()
if n_eval_episodes > 0:
self.policy.eval()
def step(self):
"""