Begin implementing training loop in session
This commit is contained in:
@@ -52,6 +52,7 @@ class SB3Policy(PolicyABC):
|
||||
num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session
|
||||
deterministic = True # TODO: populate values once I figure out how to get them from the config / session
|
||||
|
||||
# TODO: consider moving this loop into the session, but only if that also makes sense for Ray RLlib
|
||||
for episode in range(num_episodes):
|
||||
obs = self.session.env.reset()
|
||||
for step in range(time_steps):
|
||||
|
||||
@@ -9,6 +9,7 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -59,31 +60,51 @@ class PrimaiteSession:
|
||||
def __init__(self):
    """Set up an empty session; several attributes are only declared here and are not assigned a value."""
    self.simulation: Simulation = Simulation()
    """The simulation that the agents interact with."""

    self.agents: List[AbstractAgent] = []
    """All agents taking part in the session."""

    self.rl_agent: AbstractAgent
    """The agent from the list which communicates with GATE to perform reinforcement learning."""

    # NOTE(review): the commented-out copy below duplicates the declaration above; it is kept verbatim
    # because it is unclear from here which of the two is meant to survive -- confirm and delete one.
    # self.rl_agent: AbstractAgent
    # """The agent from the list which communicates with GATE to perform reinforcement learning."""

    self.step_counter: int = 0
    """Timestep reached within the current episode."""

    self.episode_counter: int = 0
    """Number of the current episode."""

    self.options: PrimaiteSessionOptions
    """Game-wide options (declared only; no value is assigned in this method)."""

    self.training_options: TrainingOptions
    """Agent-training options (declared only; no value is assigned in this method)."""

    self.policy: PolicyABC
    """Reinforcement learning policy (declared only; no value is assigned in this method)."""

    # Lookup tables used while parsing config files; all start out empty.
    self.ref_map_nodes: Dict[str, Node] = {}
    """Unique node reference name -> node object."""

    self.ref_map_services: Dict[str, Service] = {}
    """Human-readable service reference -> service object."""

    self.ref_map_applications: Dict[str, Application] = {}
    """Human-readable application reference -> application object."""

    self.ref_map_links: Dict[str, Link] = {}
    """Human-readable link reference -> link object."""
|
||||
|
||||
def start_session(self) -> None:
    """
    Run the training session: a learning phase followed by an evaluation phase.

    The episode budgets are read from ``self.training_options``. ``self.policy.learn()`` is called only
    when ``n_learn_episodes`` is greater than zero, and ``self.policy.eval()`` only when
    ``n_eval_episodes`` is greater than zero. Both budgets are read before either phase starts.
    """
    # An unconditional ``raise NotImplementedError`` previously sat at the top of this method and made
    # everything below unreachable; it has been removed so that the phases actually run.
    # TODO: the step budgets (``n_learn_steps`` / ``n_eval_steps``) are not consumed yet.
    n_learn_episodes = self.training_options.n_learn_episodes
    n_eval_episodes = self.training_options.n_eval_episodes

    if n_learn_episodes > 0:
        self.policy.learn()

    if n_eval_episodes > 0:
        self.policy.eval()
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user