Files
PrimAITE/src/primaite/game/agent/interface.py

74 lines
3.0 KiB
Python
Raw Normal View History

2023-09-26 12:54:56 +01:00
# TODO: remove this comment... This is just here to point out that I've named this 'actor' rather than 'agent'
# That's because I want to point out that this is distinct from 'agent' in the reinforcement learning sense of the word
# If you disagree, make a comment in the PR review and we can discuss
from abc import ABC, abstractmethod
2023-10-02 17:21:43 +01:00
from typing import Any, Dict, List, Optional, Union, TypeAlias
import numpy as np
2023-09-26 12:54:56 +01:00
2023-10-02 17:21:43 +01:00
from primaite.game.agent.actions import ActionManager
2023-09-26 12:54:56 +01:00
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
2023-10-02 17:21:43 +01:00
ObsType:TypeAlias = Union[Dict, np.ndarray]
2023-09-26 12:54:56 +01:00
class AbstractAgent(ABC):
    """Base class for scripted and RL agents.

    Subclasses must implement ``get_action`` and ``format_request``. The
    action space, observation space, and reward function are optional
    collaborators injected at construction time.
    """

    def __init__(
        self,
        action_space: Optional[ActionManager],
        observation_space: Optional[ObservationSpace],
        reward_function: Optional[RewardFunction],
    ) -> None:
        """Initialise the agent with its spaces and reward function.

        :param action_space: Manager defining the actions available to this agent.
        :param observation_space: Space used to build observations from simulator state.
        :param reward_function: Function used to compute rewards from simulator state.
        """
        self.action_space: Optional[ActionManager] = action_space
        self.observation_space: Optional[ObservationSpace] = observation_space
        self.reward_function: Optional[RewardFunction] = reward_function
        # Execution definition converts a CAOS action to a PrimAITE simulator request,
        # sometimes having to enrich the info by, for example, specifying target ip
        # addresses, or converting a node ID into a uuid.
        self.execution_definition = None

    def get_obs_from_state(self, state: Dict) -> ObsType:
        """Convert raw simulator state into this agent's observation.

        :param state: dict state directly from simulation.describe_state.
        :return: dict state according to CAOS.
        """
        return self.observation_space.observe(state)

    def get_reward_from_state(self, state: Dict) -> float:
        """Compute this agent's reward from raw simulator state.

        :param state: dict state directly from simulation.describe_state.
        :return: scalar reward value.
        """
        return self.reward_function.calculate(state)

    @abstractmethod
    def get_action(self, obs: ObsType, reward: Optional[float] = None):
        """Choose a CAOS-format action from the given observation.

        In an RL agent, this method will send the CAOS observation to the GATE RL
        agent, then receive an int 1-40, then use a bespoke conversion to take the
        1-40 int back into a CAOS action.
        """
        return ("NODE", "SERVICE", "SCAN", "<fake-node-sid>", "<fake-service-sid>")

    @abstractmethod
    def format_request(self, action) -> List[str]:
        """Format action into format expected by the simulator, and apply execution definition if applicable.

        This will take something like APPLICATION.EXECUTE and add things like
        target_ip_address in the simulator; the execution definition therefore
        needs to be a mapping from CAOS into SIMULATOR.
        """
        return ["network", "nodes", "<fake-node-uuid>", "file_system", "folder", "root", "scan"]
class AbstractScriptedAgent(AbstractAgent):
    """Base class for actors which generate their own behaviour."""

    pass
class RandomAgent(AbstractScriptedAgent):
    """Agent that ignores its observation and acts completely at random."""

    def get_action(self, obs: ObsType, reward: Optional[float] = None):
        """Sample a random action from the action space, ignoring the observation.

        :param obs: Current observation (unused).
        :param reward: Most recent reward (unused).
        :return: A randomly sampled action from ``self.action_space.space``.
        """
        return self.action_space.space.sample()
class AbstractGATEAgent(AbstractAgent):
    """Base class for actors controlled via external messages, such as RL policies."""

    pass