Start implementing green agent logic for UC2

This commit is contained in:
Marek Wolan
2024-02-26 20:08:13 +00:00
parent 922298eaf0
commit c54f82fb1b
2 changed files with 64 additions and 3 deletions

View File

@@ -1,10 +1,70 @@
"""Agents with predefined behaviours."""
from typing import Dict, Optional, Tuple
import numpy as np
import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
class GreenWebBrowsingAgent(AbstractScriptedAgent):
class GreenUC2Agent(AbstractScriptedAgent):
"""Scripted agent which attempts to send web requests to a target node."""
class GreenUC2AgentSettings(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="forbid")
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]:
if not abs(sum(v.values()) - 1) < 1e-6:
raise ValueError(f"Green action probabilities must sum to 1")
return v
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]:
if not all((i in v) for i in range(len(v))):
raise ValueError(
"Green action probabilities must be defined as a mapping where the keys are consecutive integers "
"from 0 to N."
)
def __init__(
self,
agent_name: str,
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
settings: Dict = {},
) -> None:
# If the action probabilities are not specified, create equal probabilities for all actions
if "action_probabilities" not in settings:
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
self.settings = GreenUC2Agent.GreenUC2AgentSettings(settings)
self.rng = np.random.default_rng(self.settings.random_seed)
# convert probabilities from
self.probabilities = np.array[self.settings.action_probabilities.values()]
super().__init__(agent_name, action_space, observation_space, reward_function)
def get_action(self, obs: ObsType, reward: float = 0) -> Tuple[str, Dict]:
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
return self.action_manager.get_action(choice)
raise NotImplementedError

View File

@@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import GreenUC2Agent
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -392,9 +393,9 @@ class PrimaiteGame:
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
if agent_type == "GreenUC2Agent":
# TODO: implement non-random agents and fix this parsing
new_agent = RandomAgent(
new_agent = GreenUC2Agent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,