diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f034f9ea..700a45fd 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -50,9 +50,10 @@ game_config: - type: DUMMY agent_settings: - start_step: 5 - frequency: 4 - variance: 3 + start_settings: + start_step: 5 + frequency: 4 + variance: 3 - ref: client_1_data_manipulation_red_bot team: RED @@ -106,9 +107,10 @@ game_config: - type: DUMMY agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_step: 25 - frequency: 20 - variance: 5 + start_settings: + start_step: 25 + frequency: 20 + variance: 5 - ref: defender team: BLUE diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index c591c554..70eb1980 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -25,6 +25,24 @@ class AgentExecutionDefinition(BaseModel): "The probability of data manipulation succeeding." 
+class AgentStartSettings(BaseModel): + """Configuration values for when an agent starts performing actions.""" + + start_step: int = 5 + "The timestep at which an agent begins performing its actions" + frequency: int = 5 + "The number of timesteps to wait between performing actions" + variance: int = 0 + "The maximum amount by which the frequency can randomly vary" + + +class AgentSettings(BaseModel): + """Settings for configuring the operation of an agent.""" + + start_settings: Optional[AgentStartSettings] = None + "Configuration for when an agent begins performing its actions" + + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -35,6 +53,7 @@ class AbstractAgent(ABC): observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], execution_definition: Optional[AgentExecutionDefinition], + agent_settings: Optional[AgentSettings], ) -> None: """ Initialize an agent. @@ -57,6 +76,8 @@ class AbstractAgent(ABC): # by for example specifying target ip addresses, or converting a node ID into a uuid self.execution_definition = execution_definition or AgentExecutionDefinition() + self.agent_settings = agent_settings or AgentSettings() + def convert_state_to_obs(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space.
diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 5f3fb7b9..9701fec9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,7 +10,13 @@ from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent +from primaite.game.agent.interface import ( + AbstractAgent, + AgentExecutionDefinition, + AgentSettings, + DataManipulationAgent, + RandomAgent, +) from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -439,6 +445,7 @@ class PrimaiteSession: rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {})) + agent_settings = AgentSettings(**agent_cfg.get("agent_settings", {})) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": @@ -449,6 +456,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, execution_definition=execution_definition, + agent_settings=agent_settings, ) sess.agents.append(new_agent) elif agent_type == "GATERLAgent": @@ -458,6 +466,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, execution_definition=execution_definition, + agent_settings=agent_settings, ) sess.agents.append(new_agent) sess.rl_agent = new_agent @@ -468,6 +477,7 @@ class PrimaiteSession: observation_space=obs_space, reward_function=rew_function, execution_definition=execution_definition, + agent_settings=agent_settings, ) sess.agents.append(new_agent) else: