diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 9c6b49ae..151e860d 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -52,6 +52,7 @@ class SB3Policy(PolicyABC): num_episodes = 10 # TODO: populate values once I figure out how to get them from the config / session deterministic = True # TODO: populate values once I figure out how to get them from the config / session + # TODO: consider moving this loop to the session, but only if this also makes sense for Ray RLlib for episode in range(num_episodes): obs = self.session.env.reset() for step in range(time_steps): diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 459d9668..a088d05e 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -9,6 +9,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, RandomAgent from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction +from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -59,31 +60,51 @@ class PrimaiteSession: def __init__(self): self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + self.agents: List[AbstractAgent] = [] """List of agents.""" + + # self.rl_agent: AbstractAgent - """The agent from the list which communicates with GATE to perform reinforcement learning.""" + + # self.rl_agent: AbstractAgent + # """The agent from the list which communicates with GATE to perform reinforcement learning.""" + self.step_counter: int = 0 """Current timestep within the episode.""" + self.episode_counter: int = 0 """Current episode number.""" + self.options: PrimaiteSessionOptions 
"""Special options that apply for the entire game.""" + self.training_options: TrainingOptions """Options specific to agent training.""" + self.policy: PolicyABC + """The reinforcement learning policy.""" + self.ref_map_nodes: Dict[str, Node] = {} """Mapping from unique node reference name to node object. Used when parsing config files.""" + self.ref_map_services: Dict[str, Service] = {} """Mapping from human-readable service reference to service object. Used for parsing config files.""" + self.ref_map_applications: Dict[str, Application] = {} """Mapping from human-readable application reference to application object. Used for parsing config files.""" + self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" def start_session(self) -> None: """Commence the training session, this gives the GATE client control over the simulation/agent loop.""" - raise NotImplementedError + # n_learn_steps = self.training_options.n_learn_steps + n_learn_episodes = self.training_options.n_learn_episodes + # n_eval_steps = self.training_options.n_eval_steps + n_eval_episodes = self.training_options.n_eval_episodes + if n_learn_episodes > 0: + self.policy.learn() + + if n_eval_episodes > 0: + self.policy.eval() def step(self): """