Start implementing green agent logic for UC2
This commit is contained in:
@@ -1,10 +1,70 @@
|
||||
"""Agents with predefined behaviours."""
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GreenWebBrowsingAgent(AbstractScriptedAgent):
|
||||
class GreenUC2Agent(AbstractScriptedAgent):
|
||||
"""Scripted agent which attempts to send web requests to a target node."""
|
||||
|
||||
class GreenUC2AgentSettings(pydantic.BaseModel):
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
action_probabilities: Dict[int, float]
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
random_seed: Optional[int] = None
|
||||
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
|
||||
if not abs(sum(v.values()) - 1) < 1e-6:
|
||||
raise ValueError(f"Green action probabilities must sum to 1")
|
||||
return v
|
||||
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]:
|
||||
if not all((i in v) for i in range(len(v))):
|
||||
raise ValueError(
|
||||
"Green action probabilities must be defined as a mapping where the keys are consecutive integers "
|
||||
"from 0 to N."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
settings: Dict = {},
|
||||
) -> None:
|
||||
# If the action probabilities are not specified, create equal probabilities for all actions
|
||||
if "action_probabilities" not in settings:
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
|
||||
# If seed not specified, set it to None so that numpy chooses a random one.
|
||||
settings.setdefault("random_seed")
|
||||
|
||||
self.settings = GreenUC2Agent.GreenUC2AgentSettings(settings)
|
||||
|
||||
self.rng = np.random.default_rng(self.settings.random_seed)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.array[self.settings.action_probabilities.values()]
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0) -> Tuple[str, Dict]:
|
||||
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
|
||||
return self.action_manager.get_action(choice)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents import GreenUC2Agent
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -392,9 +393,9 @@ class PrimaiteGame:
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
if agent_type == "GreenUC2Agent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
new_agent = RandomAgent(
|
||||
new_agent = GreenUC2Agent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
|
||||
Reference in New Issue
Block a user