Pass agent settings from config to agent
This commit is contained in:
@@ -50,9 +50,10 @@ game_config:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
start_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
@@ -106,9 +107,10 @@ game_config:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
|
||||
@@ -25,6 +25,24 @@ class AgentExecutionDefinition(BaseModel):
|
||||
"The probability of data manipulation succeeding."
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing it's actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to"
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
"""Settings for configuring the operation of an agent."""
|
||||
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
@@ -35,6 +53,7 @@ class AbstractAgent(ABC):
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
execution_definition: Optional[AgentExecutionDefinition],
|
||||
agent_settings: Optional[AgentSettings],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
@@ -57,6 +76,8 @@ class AbstractAgent(ABC):
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = execution_definition or AgentExecutionDefinition()
|
||||
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
@@ -10,7 +10,13 @@ from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent
|
||||
from primaite.game.agent.interface import (
|
||||
AbstractAgent,
|
||||
AgentExecutionDefinition,
|
||||
AgentSettings,
|
||||
DataManipulationAgent,
|
||||
RandomAgent,
|
||||
)
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
@@ -439,6 +445,7 @@ class PrimaiteSession:
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
|
||||
|
||||
execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {}))
|
||||
agent_settings = AgentSettings(**agent_cfg.get("agent_settings", {}))
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
@@ -449,6 +456,7 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
execution_definition=execution_definition,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
@@ -458,6 +466,7 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
execution_definition=execution_definition,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agent = new_agent
|
||||
@@ -468,6 +477,7 @@ class PrimaiteSession:
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
execution_definition=execution_definition,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user