2869 - Addressing some typos in agent declaration, and neatening up the agent structure within PrimAITE.

This commit is contained in:
Charlie Crane
2024-12-16 11:27:14 +00:00
parent c3a70be8d1
commit d9a1a0e26f
28 changed files with 67 additions and 40 deletions

View File

@@ -69,4 +69,7 @@ Changes to YAML file
Agent configurations specified within YAML files used for earlier versions of PrimAITE will need updating to be compatible with PrimAITE v4.0.0+.
Agents now follow a more standardised settings definition, so should be more consistent across YAML.
# TODO: Show changes to YAML config needed here

View File

@@ -36,7 +36,7 @@ from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]

View File

@@ -1 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.scripted_agents import (
abstract_tap,
data_manipulation_bot,
interface,
probabilistic_agent,
random_agent,
)
__all__ = ("abstract_tap", "data_manipulation_bot", "interface", "probabilistic_agent", "random_agent")

View File

@@ -3,8 +3,11 @@ from __future__ import annotations
import random
from abc import abstractmethod
from typing import Dict, Tuple
from primaite.game.agent.interface import AbstractScriptedAgent
from gymnasium.core import ObsType
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
@@ -12,12 +15,17 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
config: "AbstractTAPAgent.ConfigSchema"
agent_name: str = "Abstract_TAP"
_next_execution_timestep: int
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Abstract TAP agents."""
starting_node_name: str
next_execution_timestep: int
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Return an action to be taken in the environment."""
return super().get_action(obs=obs, timestep=timestep)
@abstractmethod
def setup_agent(self) -> None:
@@ -32,7 +40,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"):
random_timestep_increment = random.randint(
-self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance
)
self.config.next_execution_timestep = timestep + random_timestep_increment
self._next_execution_timestep = timestep + random_timestep_increment
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""

View File

@@ -24,7 +24,7 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen
@property
def next_execution_timestep(self) -> int:
"""Returns the agents next execution timestep."""
return self.config.next_execution_timestep
return self._next_execution_timestep
@property
def starting_node_name(self) -> str:

View File

@@ -3,10 +3,10 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.agent_log import AgentLog
@@ -92,7 +92,7 @@ class AgentSettings(BaseModel):
return cls(**config)
class AbstractAgent(BaseModel, identifier="Abstract_Agent"):
class AbstractAgent(BaseModel):
"""Base class for scripted and RL agents."""
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
@@ -116,6 +116,7 @@ class AbstractAgent(BaseModel, identifier="Abstract_Agent"):
"""
agent_name: str = "Abstract_Agent"
model_config = ConfigDict(extra="forbid")
history: List[AgentHistoryItem] = []
_logger: AgentLog = AgentLog(agent_name=agent_name)
_action_manager: Optional[ActionManager] = None
@@ -124,10 +125,10 @@ class AbstractAgent(BaseModel, identifier="Abstract_Agent"):
_agent_settings: Optional[AgentSettings] = None
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
super().__init_subclass__(**kwargs)
@property
def logger(self) -> AgentLog:
@@ -218,6 +219,8 @@ class AbstractAgent(BaseModel, identifier="Abstract_Agent"):
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
"""Base class for actors which generate their own behaviour."""
config: "AbstractScriptedAgent.ConfigSchema"
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for AbstractScriptedAgents."""
@@ -233,20 +236,20 @@ class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"):
"""Agent that sends observations to an RL model and receives actions from that model."""
config: "ProxyAgent.ConfigSchema"
_most_recent_action: ActType
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Proxy Agent."""
agent_name: str = "Proxy_Agent"
agent_settings = Union[AgentSettings | None] = None
most_reason_action: ActType
agent_settings: AgentSettings = None
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
action_masking: bool = agent_settings.action_masking if agent_settings else False
@property
def most_recent_action(self) -> ActType:
"""Convenience method to access the agents most recent action."""
return self.config.most_recent_action
return self._most_recent_action
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -7,14 +7,14 @@ import pydantic
from gymnasium.core import ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"):
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
config: "ProbabilisticAgent.ConfigSchema"
agent_name: str = "Probabilistic_Agent"
agent_name: str = "ProbabilisticAgent"
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""
@@ -42,10 +42,11 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"
)
return v
def __init__(self) -> None:
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
# def __init__(self, **kwargs) -> None:
# rng_seed = np.random.randint(0, 65535)
# self.rng = np.random.default_rng(rng_seed)
# self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
# super().__init_subclass__(**kwargs)
@property
def probabilities(self) -> Dict[str, int]:

View File

@@ -4,7 +4,7 @@ from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent
class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"):

View File

@@ -8,9 +8,9 @@ from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.interface import AbstractAgent, ProxyAgent
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
@@ -549,6 +549,8 @@ class PrimaiteGame:
{"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function}
)
# new_agent_cfg.update{}
print(AbstractAgent._registry)
if agent_type in AbstractAgent._registry:
new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config)
# If blue agent is created, add to game.rl_agents

View File

@@ -10,7 +10,7 @@ import numpy as np
from gymnasium.core import ActType, ObsType
from primaite import getLogger
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO

View File

@@ -7,7 +7,7 @@ from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler

View File

@@ -7,9 +7,9 @@ from ray import init as rayinit
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.interface import AbstractAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network

View File

@@ -6,8 +6,8 @@ from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.base import UserManager

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -6,8 +6,8 @@ import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer

View File

@@ -5,7 +5,7 @@ import pytest
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv

View File

@@ -17,7 +17,7 @@ from typing import Tuple
import pytest
import yaml
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus

View File

@@ -2,8 +2,8 @@
import pytest
import yaml
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv

View File

@@ -5,7 +5,7 @@ from typing import Tuple
import pytest
import yaml
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.container import Network

View File

@@ -1,11 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import (
GreenAdminDatabaseUnreachablePenalty,
WebpageUnavailablePenalty,
WebServer404Penalty,
)
from primaite.game.agent.scripted_agents.interface import AgentHistoryItem
from primaite.interface.request import RequestResponse

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer