Pass execution definition from config to agent

This commit is contained in:
Jake Walker
2023-11-16 13:26:30 +00:00
parent 23fd9c3839
commit 1c5ff66d26
4 changed files with 27 additions and 4 deletions

View File

@@ -58,6 +58,10 @@ game_config:
team: RED
type: RedDatabaseCorruptingAgent
execution_definition:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
observation_space:
type: UC2RedObservation
options:

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
import numpy as np
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
@@ -11,6 +12,11 @@ from primaite.game.agent.rewards import RewardFunction
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AgentExecutionDefinition(BaseModel):
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
@@ -20,6 +26,7 @@ class AbstractAgent(ABC):
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
execution_definition: Optional[AgentExecutionDefinition]
) -> None:
"""
Initialize an agent.
@@ -40,7 +47,7 @@ class AbstractAgent(ABC):
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
self.execution_definition = execution_definition or AgentExecutionDefinition()
def convert_state_to_obs(self, state: Dict) -> ObsType:
"""
@@ -110,7 +117,11 @@ class RandomAgent(AbstractScriptedAgent):
return self.action_space.get_action(self.action_space.space.sample())
class DataManipulationAgent(AbstractScriptedAgent):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
return self.action_space.get_action(self.action_space.space.sample())
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""

View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node
@@ -438,6 +438,8 @@ class PrimaiteSession:
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {}))
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
# TODO: implement non-random agents and fix this parsing
@@ -446,6 +448,7 @@ class PrimaiteSession:
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
execution_definition=execution_definition,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
@@ -454,15 +457,17 @@ class PrimaiteSession:
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
execution_definition=execution_definition,
)
sess.agents.append(new_agent)
sess.rl_agent = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
execution_definition=execution_definition,
)
sess.agents.append(new_agent)
else:

View File

@@ -135,3 +135,6 @@ class WebBrowser(Application):
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
self.latest_response = payload
return True
def execute(self):
pass