Get primaite session step working

This commit is contained in:
Marek Wolan
2023-10-06 20:32:52 +01:00
parent 2a8df074b9
commit 3dea9743c3
7 changed files with 520 additions and 356 deletions

View File

@@ -1,12 +1,14 @@
import itertools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gym import spaces
from primaite.game.session import PrimaiteSession
from primaite.simulator.sim_container import Simulation
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
class ExecutionDefiniton(ABC):
"""
@@ -59,7 +61,7 @@ class DoNothingAction(AbstractAction):
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self) -> List[str]:
def form_request(self, **kwargs) -> List[str]:
return ["do_nothing"]
@@ -86,56 +88,56 @@ class NodeServiceAbstractAction(AbstractAction):
class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
super().__init__(manager=manager)
def __init__(self, manager: "ActionManager", num_nodes:int, num_services:int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
class NodeFolderAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
@@ -149,26 +151,26 @@ class NodeFolderAbstractAction(AbstractAction):
class NodeFolderScanAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes, num_folders, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, **kwargs)
def __init__(self, manager: "ActionManager", num_nodes:int, num_folders:int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "restore"
@@ -190,34 +192,40 @@ class NodeFileAbstractAction(AbstractAction):
class NodeFileScanAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes, num_folders, num_files, **kwargs)
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "corrupt"
class NodeAbstractAction(AbstractAction):
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
@@ -232,25 +240,25 @@ class NodeAbstractAction(AbstractAction):
class NodeOSScanAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "scan"
class NodeShutdownAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "shutdown"
class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "start"
class NodeResetAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "reset"
@@ -371,6 +379,7 @@ class ActionManager:
"NODE_FILE_DELETE": NodeFileDeleteAction,
"NODE_FILE_REPAIR": NodeFileRepairAction,
"NODE_FILE_RESTORE": NodeFileRestoreAction,
"NODE_FILE_CORRUPT": NodeFileCorruptAction,
"NODE_FOLDER_SCAN": NodeFolderScanAction,
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
@@ -387,7 +396,7 @@ class ActionManager:
def __init__(
self,
session: PrimaiteSession, # reference to session for looking up stuff
session: "PrimaiteSession", # reference to session for looking up stuff
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
max_folders_per_node: int = 2, # allows calculating shape
@@ -400,7 +409,7 @@ class ActionManager:
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
self.session: PrimaiteSession = session
self.session: "PrimaiteSession" = session
self.sim: Simulation = self.session.simulation
self.node_uuids: List[str] = node_uuids
self.protocols: List[str] = protocols
@@ -417,7 +426,8 @@ class ActionManager:
for nic_uuid, nic_obj in nics.items():
self.ip_address_list.append(nic_obj.ip_address)
action_args = {
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
"num_nodes": len(node_uuids),
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
@@ -427,10 +437,21 @@ class ActionManager:
"num_protocols": len(self.protocols),
"num_ports": len(self.protocols),
"num_ips": len(self.ip_address_list),
"max_acl_rules":max_acl_rules,
"max_nics_per_node": max_nics_per_node,
}
self.actions: Dict[str, AbstractAction] = {}
for act_type in actions:
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args)
for act_spec in actions:
# each action is provided into the action space config like this:
# - type: ACTION_TYPE
# options:
# option_1: value1
# option_2: value2
# where `type` decides which AbstractAction subclass should be used
# and `options` is an optional dict of options to pass to the init method of the action class
act_type = act_spec.get('type')
act_options = act_spec.get('options', {})
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)
self.action_map: Dict[int, Tuple[str, Dict]] = {}
"""
@@ -448,15 +469,41 @@ class ActionManager:
def _enumerate_actions(
self,
) -> Dict[int, Tuple[AbstractAction, Dict]]:
) -> Dict[int, Tuple[str, Dict]]:
"""Generate a list of all the possible actions that could be taken.
This enumerates all actions all combinations of parametes you could choose for those actions. The output
of this function is intended to populate the self.action_map parameter in the situation where the user provides
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
The enumeration relies on the Actions' `shape` attribute.
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
An example output could be:
{0: ("DONOTHING", {'dummy': 0}),
1: ("NODE_OS_SCAN", {'node_id': 0}),
2: ("NODE_OS_SCAN", {'node_id': 1}),
3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}),
... #etc...
}
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
"""
all_action_possibilities = []
for action in self.actions.values():
param_names = (list(action.shape.keys()),)
for act_name, action in self.actions.items():
param_names = list(action.shape.keys())
num_possibilities = list(action.shape.values())
possibilities = [range(n) for n in num_possibilities]
itertools.product(action.shape.values())
all_action_possibilities.append((action, {}))
param_combinations = list(itertools.product(*possibilities))
all_action_possibilities.extend(
[
(
act_name, {param_names[i]:param_combinations[j][i] for i in range(len(param_names))}
) for j in range(len(param_combinations))]
)
return {i:p for i,p in enumerate(all_action_possibilities)}
def get_action(self, action: int) -> Tuple[str, Dict]:
"""Produce action in CAOS format"""
@@ -517,21 +564,16 @@ class ActionManager:
return nics[nic_idx]
@classmethod
def from_config(cls, session: PrimaiteSession, cfg: Dict) -> "ActionManager":
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
obj = cls(
session=session,
actions=cfg["action_list"],
node_uuids=cfg["options"]["nodes"],
max_folders_per_node=cfg["options"]["max_folders_per_node"],
max_files_per_folder=cfg["options"]["max_files_per_folder"],
max_services_per_node=cfg["options"]["max_services_per_node"],
max_nics_per_node=cfg["options"]["max_nics_per_node"],
max_acl_rules=cfg["options"]["max_acl_rules"],
max_X=cfg["options"]["max_X"],
protocols=session.options.ports,
ports=session.options.protocols,
# node_uuids=cfg["options"]["node_uuids"],
**cfg['options'],
protocols=session.options.protocols,
ports=session.options.ports,
ip_address_list=None,
act_map=cfg["action_map"],
act_map=cfg.get("action_map"),
)
return obj

View File

@@ -2,7 +2,7 @@
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
# If you disagree, make a comment in the PR review and we can discuss
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, TypeAlias, Union
from typing import Any, Dict, List, Optional, Tuple, TypeAlias, Union
import numpy as np
@@ -41,17 +41,17 @@ class AbstractAgent(ABC):
return self.reward_function.calculate(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = None):
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("NODE", "SERVICE", "SCAN", "<fake-node-sid>", "<fake-service-sid>")
return ("DO_NOTHING", {} )
@abstractmethod
def format_request(self, action) -> List[str]:
def format_request(self, action:Tuple[str,Dict], options:Dict[str, int]) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
return ["network", "nodes", "<fake-node-uuid>", "file_system", "folder", "root", "scan"]
request = self.action_space.form_request(action_identifier=action, action_options=options)
return request
class AbstractScriptedAgent(AbstractAgent):
@@ -63,8 +63,8 @@ class AbstractScriptedAgent(AbstractAgent):
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = None):
return self.action_space.space.sample()
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
return self.action_space.get_action(self.action_space.space.sample())
class AbstractGATEAgent(AbstractAgent):

View File

@@ -5,13 +5,30 @@ class AbstractReward():
def __init__(self):
...
@abstractmethod
def calculate(self, state:Dict) -> float:
return 0.3
class DummyReward(AbstractReward):
def calculate(self, state: Dict) -> float:
return -0.1
class RewardFunction():
__rew_class_identifiers:Dict[str,type[AbstractReward]] = {
"DUMMY" : DummyReward
}
def __init__(self, reward_function:AbstractReward):
self.reward: AbstractReward = reward_function
def calculate(self, state:Dict) -> float:
return self.reward.calculate(state)
@classmethod
def from_config(cls, cfg:Dict) -> "RewardFunction":
for rew_component_cfg in cfg['reward_components']:
rew_type = rew_component_cfg['type']
rew_component = cls.__rew_class_identifiers[rew_type]()
new = cls(reward_function=rew_component)
return new

View File

@@ -11,7 +11,7 @@ from typing import Dict, List
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.observations import (
AclObservation,
FileObservation,
@@ -25,6 +25,7 @@ from primaite.game.agent.observations import (
UC2BlueObservation,
UC2RedObservation,
)
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -74,10 +75,10 @@ class PrimaiteSession:
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
agent_action = agent.get_action(agent_obs, agent_reward)
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
agent_request = agent.format_request(agent_action)
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
self.simulation.apply_action(agent_request)
@@ -88,6 +89,10 @@ class PrimaiteSession:
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
sess = cls()
sess.options = PrimaiteSessionOptions(
ports = cfg['game_config']['ports'],
protocols = cfg['game_config']['protocols'],
)
sim = sess.simulation
net = sim.network
@@ -304,13 +309,33 @@ class PrimaiteSession:
obs_space = NullObservation()
# CREATE ACTION SPACE
action_space_cfg['options']['node_uuids'] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get('options',{}).pop('nodes', {}):
if 'node_ref' in action_node_option:
node_uuid = ref_map_nodes[action_node_option['node_ref']]
action_space_cfg['options']['node_uuids'].append(node_uuid)
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
# However, it's not possible to specify the node uuids directly in the config, as they are generated
# dynamically, so we have to translate node references to uuids before passing this config on.
if 'action_list' in action_space_cfg:
for action_config in action_space_cfg['action_list']:
if 'options' in action_config:
if 'target_router_ref' in action_config['options']:
_target = action_config['options']['target_router_ref']
action_config['options']['target_router_uuid'] = ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
...
new_agent = RandomAgent(action_space=action_space, observation_space=obs_space, reward_function=rew_function)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
...
elif agent_type == "RedDatabaseCorruptingAgent":
@@ -318,4 +343,5 @@ class PrimaiteSession:
else:
print("agent type not found")
return sess

View File

@@ -27,6 +27,7 @@ class Simulation(SimComponent):
am.add_action("network", Action(func=self.network._action_manager))
# pass through domain actions to the domain object
am.add_action("domain", Action(func=self.domain._action_manager))
am.add_action("do_nothing", Action(func=lambda request, context: ()))
return am
def describe_state(self) -> Dict: