Get primaite session step working
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gym import spaces
|
||||
|
||||
from primaite.game.session import PrimaiteSession
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class ExecutionDefiniton(ABC):
|
||||
"""
|
||||
@@ -59,7 +61,7 @@ class DoNothingAction(AbstractAction):
|
||||
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self) -> List[str]:
|
||||
def form_request(self, **kwargs) -> List[str]:
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
@@ -86,56 +88,56 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
|
||||
self.verb: str
|
||||
@@ -149,26 +151,26 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, **kwargs)
|
||||
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "restore"
|
||||
|
||||
|
||||
@@ -190,34 +192,40 @@ class NodeFileAbstractAction(AbstractAction):
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "checkhash"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "delete"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "repair"
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "restore"
|
||||
|
||||
|
||||
class NodeFileCorruptAction(NodeFileAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "corrupt"
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction):
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
@@ -232,25 +240,25 @@ class NodeAbstractAction(AbstractAction):
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction):
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "reset"
|
||||
|
||||
|
||||
@@ -371,6 +379,7 @@ class ActionManager:
|
||||
"NODE_FILE_DELETE": NodeFileDeleteAction,
|
||||
"NODE_FILE_REPAIR": NodeFileRepairAction,
|
||||
"NODE_FILE_RESTORE": NodeFileRestoreAction,
|
||||
"NODE_FILE_CORRUPT": NodeFileCorruptAction,
|
||||
"NODE_FOLDER_SCAN": NodeFolderScanAction,
|
||||
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
|
||||
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
|
||||
@@ -387,7 +396,7 @@ class ActionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: PrimaiteSession, # reference to session for looking up stuff
|
||||
session: "PrimaiteSession", # reference to session for looking up stuff
|
||||
actions: List[str], # stores list of actions available to agent
|
||||
node_uuids: List[str], # allows mapping index to node
|
||||
max_folders_per_node: int = 2, # allows calculating shape
|
||||
@@ -400,7 +409,7 @@ class ActionManager:
|
||||
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
self.session: PrimaiteSession = session
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
self.protocols: List[str] = protocols
|
||||
@@ -417,7 +426,8 @@ class ActionManager:
|
||||
for nic_uuid, nic_obj in nics.items():
|
||||
self.ip_address_list.append(nic_obj.ip_address)
|
||||
|
||||
action_args = {
|
||||
# action_args are settings which are applied to the action space as a whole.
|
||||
global_action_args = {
|
||||
"num_nodes": len(node_uuids),
|
||||
"num_folders": max_folders_per_node,
|
||||
"num_files": max_files_per_folder,
|
||||
@@ -427,10 +437,21 @@ class ActionManager:
|
||||
"num_protocols": len(self.protocols),
|
||||
"num_ports": len(self.protocols),
|
||||
"num_ips": len(self.ip_address_list),
|
||||
"max_acl_rules":max_acl_rules,
|
||||
"max_nics_per_node": max_nics_per_node,
|
||||
}
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_type in actions:
|
||||
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args)
|
||||
for act_spec in actions:
|
||||
# each action is provided into the action space config like this:
|
||||
# - type: ACTION_TYPE
|
||||
# options:
|
||||
# option_1: value1
|
||||
# option_2: value2
|
||||
# where `type` decides which AbstractAction subclass should be used
|
||||
# and `options` is an optional dict of options to pass to the init method of the action class
|
||||
act_type = act_spec.get('type')
|
||||
act_options = act_spec.get('options', {})
|
||||
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
@@ -448,15 +469,41 @@ class ActionManager:
|
||||
|
||||
def _enumerate_actions(
|
||||
self,
|
||||
) -> Dict[int, Tuple[AbstractAction, Dict]]:
|
||||
) -> Dict[int, Tuple[str, Dict]]:
|
||||
"""Generate a list of all the possible actions that could be taken.
|
||||
|
||||
This enumerates all actions all combinations of parametes you could choose for those actions. The output
|
||||
of this function is intended to populate the self.action_map parameter in the situation where the user provides
|
||||
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
|
||||
|
||||
The enumeration relies on the Actions' `shape` attribute.
|
||||
|
||||
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
|
||||
An example output could be:
|
||||
{0: ("DONOTHING", {'dummy': 0}),
|
||||
1: ("NODE_OS_SCAN", {'node_id': 0}),
|
||||
2: ("NODE_OS_SCAN", {'node_id': 1}),
|
||||
3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}),
|
||||
... #etc...
|
||||
}
|
||||
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
|
||||
"""
|
||||
all_action_possibilities = []
|
||||
for action in self.actions.values():
|
||||
param_names = (list(action.shape.keys()),)
|
||||
for act_name, action in self.actions.items():
|
||||
param_names = list(action.shape.keys())
|
||||
num_possibilities = list(action.shape.values())
|
||||
possibilities = [range(n) for n in num_possibilities]
|
||||
|
||||
itertools.product(action.shape.values())
|
||||
all_action_possibilities.append((action, {}))
|
||||
param_combinations = list(itertools.product(*possibilities))
|
||||
all_action_possibilities.extend(
|
||||
[
|
||||
(
|
||||
act_name, {param_names[i]:param_combinations[j][i] for i in range(len(param_names))}
|
||||
) for j in range(len(param_combinations))]
|
||||
)
|
||||
|
||||
return {i:p for i,p in enumerate(all_action_possibilities)}
|
||||
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format"""
|
||||
@@ -517,21 +564,16 @@ class ActionManager:
|
||||
return nics[nic_idx]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: PrimaiteSession, cfg: Dict) -> "ActionManager":
|
||||
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
|
||||
obj = cls(
|
||||
session=session,
|
||||
actions=cfg["action_list"],
|
||||
node_uuids=cfg["options"]["nodes"],
|
||||
max_folders_per_node=cfg["options"]["max_folders_per_node"],
|
||||
max_files_per_folder=cfg["options"]["max_files_per_folder"],
|
||||
max_services_per_node=cfg["options"]["max_services_per_node"],
|
||||
max_nics_per_node=cfg["options"]["max_nics_per_node"],
|
||||
max_acl_rules=cfg["options"]["max_acl_rules"],
|
||||
max_X=cfg["options"]["max_X"],
|
||||
protocols=session.options.ports,
|
||||
ports=session.options.protocols,
|
||||
# node_uuids=cfg["options"]["node_uuids"],
|
||||
**cfg['options'],
|
||||
protocols=session.options.protocols,
|
||||
ports=session.options.ports,
|
||||
ip_address_list=None,
|
||||
act_map=cfg["action_map"],
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
return obj
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
|
||||
# If you disagree, make a comment in the PR review and we can discuss
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, TypeAlias, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -41,17 +41,17 @@ class AbstractAgent(ABC):
|
||||
return self.reward_function.calculate(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = None):
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("NODE", "SERVICE", "SCAN", "<fake-node-sid>", "<fake-service-sid>")
|
||||
return ("DO_NOTHING", {} )
|
||||
|
||||
@abstractmethod
|
||||
def format_request(self, action) -> List[str]:
|
||||
def format_request(self, action:Tuple[str,Dict], options:Dict[str, int]) -> List[str]:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
return ["network", "nodes", "<fake-node-uuid>", "file_system", "folder", "root", "scan"]
|
||||
request = self.action_space.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
@@ -63,8 +63,8 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None):
|
||||
return self.action_space.space.sample()
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
|
||||
@@ -5,13 +5,30 @@ class AbstractReward():
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state:Dict) -> float:
|
||||
return 0.3
|
||||
|
||||
class DummyReward(AbstractReward):
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
return -0.1
|
||||
|
||||
class RewardFunction():
|
||||
__rew_class_identifiers:Dict[str,type[AbstractReward]] = {
|
||||
"DUMMY" : DummyReward
|
||||
}
|
||||
def __init__(self, reward_function:AbstractReward):
|
||||
self.reward: AbstractReward = reward_function
|
||||
|
||||
def calculate(self, state:Dict) -> float:
|
||||
return self.reward.calculate(state)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg:Dict) -> "RewardFunction":
|
||||
for rew_component_cfg in cfg['reward_components']:
|
||||
rew_type = rew_component_cfg['type']
|
||||
rew_component = cls.__rew_class_identifiers[rew_type]()
|
||||
new = cls(reward_function=rew_component)
|
||||
return new
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import (
|
||||
AclObservation,
|
||||
FileObservation,
|
||||
@@ -25,6 +25,7 @@ from primaite.game.agent.observations import (
|
||||
UC2BlueObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -74,10 +75,10 @@ class PrimaiteSession:
|
||||
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
|
||||
# code should live inside of the GATE agent subclass)
|
||||
# gets action in CAOS format
|
||||
agent_action = agent.get_action(agent_obs, agent_reward)
|
||||
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
agent_request = agent.format_request(agent_action)
|
||||
agent_request = agent.format_request(agent_action, action_options)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
self.simulation.apply_action(agent_request)
|
||||
@@ -88,6 +89,10 @@ class PrimaiteSession:
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
sess = cls()
|
||||
sess.options = PrimaiteSessionOptions(
|
||||
ports = cfg['game_config']['ports'],
|
||||
protocols = cfg['game_config']['protocols'],
|
||||
)
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
@@ -304,13 +309,33 @@ class PrimaiteSession:
|
||||
obs_space = NullObservation()
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg['options']['node_uuids'] = []
|
||||
# if a list of nodes is defined, convert them from node references to node UUIDs
|
||||
for action_node_option in action_space_cfg.get('options',{}).pop('nodes', {}):
|
||||
if 'node_ref' in action_node_option:
|
||||
node_uuid = ref_map_nodes[action_node_option['node_ref']]
|
||||
action_space_cfg['options']['node_uuids'].append(node_uuid)
|
||||
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
|
||||
# we will pass node_uuids as a part of the action space config.
|
||||
# However, it's not possible to specify the node uuids directly in the config, as they are generated
|
||||
# dynamically, so we have to translate node references to uuids before passing this config on.
|
||||
|
||||
if 'action_list' in action_space_cfg:
|
||||
for action_config in action_space_cfg['action_list']:
|
||||
if 'options' in action_config:
|
||||
if 'target_router_ref' in action_config['options']:
|
||||
_target = action_config['options']['target_router_ref']
|
||||
action_config['options']['target_router_uuid'] = ref_map_nodes[_target]
|
||||
|
||||
action_space = ActionManager.from_config(sess, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
rew_function = RewardFunction.from_config(reward_function_cfg)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "GreenWebBrowsingAgent":
|
||||
...
|
||||
new_agent = RandomAgent(action_space=action_space, observation_space=obs_space, reward_function=rew_function)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
...
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
@@ -318,4 +343,5 @@ class PrimaiteSession:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
|
||||
return sess
|
||||
|
||||
@@ -27,6 +27,7 @@ class Simulation(SimComponent):
|
||||
am.add_action("network", Action(func=self.network._action_manager))
|
||||
# pass through domain actions to the domain object
|
||||
am.add_action("domain", Action(func=self.domain._action_manager))
|
||||
am.add_action("do_nothing", Action(func=lambda request, context: ()))
|
||||
return am
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
Reference in New Issue
Block a user