Update agent interface to work better with envs
This commit is contained in:
@@ -4,8 +4,12 @@ training_config:
|
||||
seed: 333
|
||||
n_learn_episodes: 20
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 20
|
||||
n_eval_episodes: 5
|
||||
n_eval_steps: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
|
||||
game_config:
|
||||
@@ -108,7 +112,7 @@ game_config:
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: idk???
|
||||
type: RLAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
@@ -18,7 +16,7 @@ class AbstractAgent(ABC):
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -34,24 +32,24 @@ class AbstractAgent(ABC):
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
|
||||
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = None
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
return self.observation_manager.update(state)
|
||||
|
||||
def calculate_reward_from_state(self, state: Dict) -> float:
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
Use the reward function to calculate a reward from the state.
|
||||
|
||||
@@ -60,10 +58,10 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.calculate(state)
|
||||
return self.reward_function.update(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
@@ -84,7 +82,7 @@ class AbstractAgent(ABC):
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.action_space.form_request(action_identifier=action, action_options=options)
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
|
||||
@@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
|
||||
:param obs: _description_
|
||||
@@ -107,4 +105,44 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
|
||||
:type reward: float, optional
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.most_recent_action)
|
||||
|
||||
def store_action(self, action: ActType):
|
||||
"""
|
||||
Store the most recent action taken by the agent.
|
||||
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
|
||||
pass
|
||||
|
||||
|
||||
class ObservationSpace:
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
@@ -947,15 +948,17 @@ class ObservationSpace:
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.current_observation: ObsType
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
"""
|
||||
return self.obs.observe(state)
|
||||
self.current_observation = self.obs.observe(state)
|
||||
return self.current_observation
|
||||
|
||||
@property
|
||||
def space(self) -> None:
|
||||
@@ -963,7 +966,7 @@ class ObservationSpace:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
|
||||
@@ -238,6 +238,7 @@ class RewardFunction:
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
self.current_reward: float
|
||||
|
||||
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
@@ -249,7 +250,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def update(self, state: Dict) -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -260,7 +261,8 @@ class RewardFunction:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
return total
|
||||
self.current_reward = total
|
||||
return self.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
|
||||
|
||||
@@ -1,18 +1,47 @@
|
||||
from abc import ABC, abstractclassmethod, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
"""Base class for reinforcement learning agents."""
|
||||
|
||||
_registry: Dict[str, type["PolicyABC"]] = {}
|
||||
"""
|
||||
Registry of policy types, keyed by name.
|
||||
|
||||
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, name: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a policy subclass.
|
||||
|
||||
:param name: Identifier used by from_config to create an instance of the policy.
|
||||
:type name: str
|
||||
:raises ValueError: When attempting to create a policy with a duplicate name.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {name}")
|
||||
cls._registry[name] = cls
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
"""Initialize a reinforcement learning agent."""
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
:param session: The session context.
|
||||
:type session: PrimaiteSession
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
pass
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, n_episodes: int, n_time_steps: int) -> None:
|
||||
@@ -25,30 +54,30 @@ class PolicyABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
) -> None:
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(
|
||||
self,
|
||||
) -> None:
|
||||
def load(self) -> None:
|
||||
"""Load agent from a file."""
|
||||
pass
|
||||
|
||||
def close(
|
||||
self,
|
||||
) -> None:
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def from_config(
|
||||
cls,
|
||||
) -> "PolicyABC":
|
||||
"""Create an agent from a config file."""
|
||||
pass
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
Subclasses should not call super().from_config(), they should just handle creation form config.
|
||||
"""
|
||||
# Assume that basically the contents of training_config are passed into here.
|
||||
# I should really define a config schema class using pydantic.
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config()
|
||||
|
||||
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal, TYPE_CHECKING, Union
|
||||
"""Stable baselines 3 policy."""
|
||||
from typing import Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
@@ -7,13 +8,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"]):
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
@@ -29,8 +30,8 @@ class SB3Policy(PolicyABC):
|
||||
self._agent = self._agent_class(
|
||||
policy=policy,
|
||||
env=self.session.env,
|
||||
n_steps=...,
|
||||
seed=...,
|
||||
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
|
||||
seed=seed,
|
||||
) # TODO: populate values once I figure out how to get them from the config / session
|
||||
|
||||
def learn(self, n_episodes: int, n_time_steps: int) -> None:
|
||||
@@ -68,6 +69,6 @@ class SB3Policy(PolicyABC):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(self) -> "SB3Policy":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
pass
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
@@ -31,6 +33,58 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
|
||||
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.session.apply_agent_actions()
|
||||
self.session.advance_timestep()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
|
||||
next_obs = self.agent.observation_manager.current_observation
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = ...
|
||||
info = {}
|
||||
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.session.reset()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
next_obs = self.agent.observation_manager.current_observation
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
@property
|
||||
def action_space(self) -> gymnasium.Space:
|
||||
"""Return the action space of the environment."""
|
||||
return self.agent.action_manager.action_space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return self.agent.observation_manager.observation_space
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
"""
|
||||
Global options which are applicable to all of the agents in the game.
|
||||
@@ -45,28 +99,29 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
rl_framework: str
|
||||
rl_algorithm: str
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
seed: Optional[int]
|
||||
n_learn_episodes: int
|
||||
n_learn_steps: int
|
||||
n_eval_episodes: int
|
||||
n_eval_steps: int
|
||||
n_eval_episodes: int = 0
|
||||
n_eval_steps: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise a PrimaiteSession object."""
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
|
||||
# self.rl_agent: AbstractAgent
|
||||
# """The agent from the list which communicates with GATE to perform reinforcement learning."""
|
||||
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
|
||||
@@ -94,8 +149,10 @@ class PrimaiteSession:
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
# self.env:
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
"""Commence the training session."""
|
||||
n_learn_steps = self.training_options.n_learn_steps
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_steps = self.training_options.n_eval_steps
|
||||
@@ -119,40 +176,47 @@ class PrimaiteSession:
|
||||
4. Each agent chooses an action based on the observation.
|
||||
5. Each agent converts the action to a request.
|
||||
6. The simulation applies the requests.
|
||||
|
||||
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
|
||||
interacts with should implement a step method that calls methods used by this method. For example, if using a
|
||||
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
|
||||
``self.apply_agent_actions()``.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
return self.simulation.describe_state()
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for agent in self.agents:
|
||||
# 3. primaite session asks simulation to provide initial state
|
||||
# 4. primate session gives state to all agents
|
||||
# 5. primaite session asks agents to produce an action based on most recent state
|
||||
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
|
||||
sim_state = self.simulation.describe_state()
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
|
||||
# 6. each agent takes most recent state and converts it to CAOS observation
|
||||
agent_obs = agent.convert_state_to_obs(sim_state)
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
for agent in self.agents:
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
|
||||
# 7. meanwhile each agent also takes state and calculates reward
|
||||
agent_reward = agent.calculate_reward_from_state(sim_state)
|
||||
|
||||
# 8. each agent takes observation and applies decision rule to observation to create CAOS
|
||||
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
|
||||
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
|
||||
# code should live inside of the GATE agent subclass)
|
||||
# gets action in CAOS format
|
||||
_LOGGER.debug("Getting agent action")
|
||||
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
|
||||
agent_request = agent.format_request(agent_action, action_options)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
|
||||
self.simulation.apply_request(agent_request)
|
||||
|
||||
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
self.step_counter += 1
|
||||
|
||||
@@ -161,7 +225,7 @@ class PrimaiteSession:
|
||||
return NotImplemented
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
@@ -169,7 +233,7 @@ class PrimaiteSession:
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent. Used by GATE.
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
@@ -323,7 +387,7 @@ class PrimaiteSession:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
@@ -359,15 +423,14 @@ class PrimaiteSession:
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
new_agent = RandomAgent(
|
||||
elif agent_type == "RLAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agent = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -379,4 +442,7 @@ class PrimaiteSession:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options)
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user