Pass execution definition from config to agent
This commit is contained in:
@@ -58,6 +58,10 @@ game_config:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
execution_definition:
|
||||
port_scan_p_of_success: 0.1
|
||||
data_manipulation_p_of_success: 0.1
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
@@ -11,6 +12,11 @@ from primaite.game.agent.rewards import RewardFunction
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AgentExecutionDefinition(BaseModel):
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
@@ -20,6 +26,7 @@ class AbstractAgent(ABC):
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
execution_definition: Optional[AgentExecutionDefinition]
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
@@ -40,7 +47,7 @@ class AbstractAgent(ABC):
|
||||
|
||||
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = None
|
||||
self.execution_definition = execution_definition or AgentExecutionDefinition()
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -110,7 +117,11 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
pass
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
"""Base class for actors controlled via external messages, such as RL policies."""
|
||||
|
||||
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
@@ -438,6 +438,8 @@ class PrimaiteSession:
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
|
||||
|
||||
execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {}))
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
@@ -446,6 +448,7 @@ class PrimaiteSession:
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
execution_definition=execution_definition,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
@@ -454,15 +457,17 @@ class PrimaiteSession:
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
execution_definition=execution_definition,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agent = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
execution_definition=execution_definition,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
else:
|
||||
|
||||
@@ -135,3 +135,6 @@ class WebBrowser(Application):
|
||||
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
|
||||
self.latest_response = payload
|
||||
return True
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user