Files
PrimAITE/src/primaite/game/agent/interface.py

74 lines
3.0 KiB
Python
Raw Normal View History

2023-09-26 12:54:56 +01:00
# TODO: remove this comment... This is just here to point out that I've named this 'actor' rather than 'agent'
# That's because I want to point out that this is distinct from 'agent' in the reinforcement learning sense of the word
# If you disagree, make a comment in the PR review and we can discuss
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, TypeAlias, Union

import numpy as np

from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction

# Observations are either dict-shaped or a flat numpy array, depending on the observation space in use.
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AbstractAgent(ABC):
    """Base class for scripted and RL agents."""

    def __init__(
        self,
        action_space: Optional[ActionManager],
        observation_space: Optional[ObservationSpace],
        reward_function: Optional[RewardFunction],
    ) -> None:
        """
        Initialise the agent with its action, observation, and reward components.

        :param action_space: Manager that maps CAOS actions to simulator requests.
        :param observation_space: Component that converts raw simulator state into observations.
        :param reward_function: Component that computes a scalar reward from simulator state.
        """
        self.action_space: Optional[ActionManager] = action_space
        self.observation_space: Optional[ObservationSpace] = observation_space
        self.reward_function: Optional[RewardFunction] = reward_function

        # Execution definition converts a CAOS action to a PrimAITE simulator request, sometimes having to
        # enrich the info, for example by specifying target IP addresses or converting a node ID into a UUID.
        self.execution_definition = None

    def convert_state_to_obs(self, state: Dict) -> ObsType:
        """
        Convert a raw simulator state into this agent's observation.

        :param state: Dict state directly from ``simulation.describe_state``.
        :return: State formatted according to CAOS.
        """
        return self.observation_space.observe(state)

    def calculate_reward_from_state(self, state: Dict) -> float:
        """
        Compute the scalar reward for the given simulator state.

        :param state: Dict state directly from ``simulation.describe_state``.
        :return: Reward value for the current step.
        """
        return self.reward_function.calculate(state)

    @abstractmethod
    def get_action(self, obs: ObsType, reward: Optional[float] = None):
        """
        Choose an action given the current observation (and, optionally, the latest reward).

        In an RL agent, this method will send the CAOS observation to the GATE RL agent, receive back an
        integer action index, then use a bespoke conversion to turn that integer back into a CAOS action.
        """
        return ("NODE", "SERVICE", "SCAN", "<fake-node-sid>", "<fake-service-sid>")

    @abstractmethod
    def format_request(self, action) -> List[str]:
        """
        Format action into format expected by the simulator, and apply execution definition if applicable.

        This takes something like APPLICATION.EXECUTE and adds things like target_ip_address for the
        simulator; the execution definition is therefore a mapping from CAOS into SIMULATOR.
        """
        return ["network", "nodes", "<fake-node-uuid>", "file_system", "folder", "root", "scan"]
class AbstractScriptedAgent(AbstractAgent):
    """Base class for actors which generate their own behaviour."""

    ...
class RandomAgent(AbstractScriptedAgent):
    """Agent that ignores its observation and acts completely at random."""

    def get_action(self, obs: ObsType, reward: float = None):
        """Return a random sample drawn from ``self.action_space.space``, ignoring ``obs`` and ``reward``."""
        # Both inputs are deliberately unused: this agent acts uniformly at random.
        sampler = self.action_space.space
        return sampler.sample()
class AbstractGATEAgent(AbstractAgent):
    """Base class for actors controlled via external messages, such as RL policies."""

    ...