Add documentation

This commit is contained in:
Marek Wolan
2023-10-19 01:56:40 +01:00
parent 027addc485
commit e0f8c3c5ea
6 changed files with 445 additions and 133 deletions

View File

@@ -2,10 +2,10 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 2
n_learn_steps: 128
n_eval_episodes: 2
n_eval_steps: 128
n_learn_episodes: 1
n_learn_steps: 8
n_eval_episodes: 0
n_eval_steps: 8
game_config:
@@ -451,6 +451,8 @@ game_config:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:

View File

@@ -0,0 +1 @@
"""PrimAITE Game Layer."""

View File

@@ -1,6 +1,16 @@
"""
This module contains the ActionManager class which belongs to the Agent class.
An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of
AbstractAction. The ActionManager is responsible for:
1. Creating the action space from a list of action types.
2. Converting an integer action choice into a specific action and parameter choice.
3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This
ensures that requests conform to the simulator's request format.
"""
import itertools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
@@ -13,22 +23,9 @@ if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
class ExecutionDefiniton(ABC):
"""
Converter from actions to simulator requests.
Allows adding extra data/context that defines in more detail what an action means.
"""
"""
Examples:
('node', 'service', 'scan', 2, 0) means scan the first service on node index 2
-> ['network', 'nodes', <node-idx-2-uuid>, 'services', <svc-idx-0-uuid>, 'scan'w]
"""
...
class AbstractAction(ABC):
"""Base class for actions."""
@abstractmethod
def __init__(self, manager: "ActionManager", **kwargs) -> None:
"""
@@ -46,6 +43,8 @@ class AbstractAction(ABC):
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
align with the keyword args of the form_request method."""
self.manager: ActionManager = manager
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
objects."""
@abstractmethod
def form_request(self) -> List[str]:
@@ -54,6 +53,8 @@ class AbstractAction(ABC):
class DoNothingAction(AbstractAction):
"""Action which does nothing. This is here to allow agents to be idle if they choose to."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
self.name = "DONOTHING"
@@ -65,6 +66,7 @@ class DoNothingAction(AbstractAction):
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self, **kwargs) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
@@ -77,12 +79,13 @@ class NodeServiceAbstractAction(AbstractAction):
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str
def form_request(self, node_id: int, service_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
if node_uuid is None or service_uuid is None:
@@ -91,54 +94,77 @@ class NodeServiceAbstractAction(AbstractAction):
class NodeServiceScanAction(NodeServiceAbstractAction):
"""Action which scans a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
"""Action which stops a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
"""Action which starts a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
"""Action which pauses a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
"""Action which resumes a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
"""Action which restarts a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
"""Action which disables a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
"""Action which enables a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
class NodeFolderAbstractAction(AbstractAction):
"""
Base class for folder actions.
Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager=manager)
@@ -146,6 +172,7 @@ class NodeFolderAbstractAction(AbstractAction):
self.verb: str
def form_request(self, node_id: int, folder_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
if node_uuid is None or folder_uuid is None:
@@ -154,30 +181,44 @@ class NodeFolderAbstractAction(AbstractAction):
class NodeFolderScanAction(NodeFolderAbstractAction):
"""Action which scans a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
"""Action which checks the hash of a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction):
"""Action which repairs a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction):
"""Action which restores a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "restore"
class NodeFileAbstractAction(AbstractAction):
"""Abstract base class for file actions.
Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit
from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager=manager)
@@ -185,6 +226,7 @@ class NodeFileAbstractAction(AbstractAction):
self.verb: str
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
@@ -194,42 +236,60 @@ class NodeFileAbstractAction(AbstractAction):
class NodeFileScanAction(NodeFileAbstractAction):
"""Action which scans a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
"""Action which checks the hash of a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
"""Action which deletes a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
"""Action which repairs a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
"""Action which restores a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction):
"""Action which corrupts a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "corrupt"
class NodeAbstractAction(AbstractAction):
"""
Abstract base class for node actions.
Any action which applies to a node and uses node_id as its only parameter can inherit from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
@@ -237,35 +297,46 @@ class NodeAbstractAction(AbstractAction):
self.verb: str
def form_request(self, node_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
return ["network", "node", node_uuid, self.verb]
class NodeOSScanAction(NodeAbstractAction):
"""Action which scans a node's OS."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "scan"
class NodeShutdownAction(NodeAbstractAction):
"""Action which shuts down a node."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "shutdown"
class NodeStartupAction(NodeAbstractAction):
"""Action which starts up a node."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "startup"
class NodeResetAction(NodeAbstractAction):
"""Action which resets a node."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "reset"
class NetworkACLAddRuleAction(AbstractAction):
"""Action which adds a rule to a router's ACL."""
def __init__(
self,
manager: "ActionManager",
@@ -276,6 +347,21 @@ class NetworkACLAddRuleAction(AbstractAction):
num_protocols: int,
**kwargs,
) -> None:
"""Init method for NetworkACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param target_router_uuid: UUID of the router to which the ACL rule should be added.
:type target_router_uuid: str
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
:type num_ips: int
:param num_ports: Number of ports in the simulation.
:type num_ports: int
:param num_protocols: Number of protocols in the simulation.
:type num_protocols: int
"""
super().__init__(manager=manager)
num_permissions = 3
self.shape: Dict[str, int] = {
@@ -290,8 +376,16 @@ class NetworkACLAddRuleAction(AbstractAction):
self.target_router_uuid: str = target_router_uuid
def form_request(
self, position, permission, source_ip_id, dest_ip_id, source_port_id, dest_port_id, protocol_id
self,
position: int,
permission: int,
source_ip_id: int,
dest_ip_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
@@ -354,22 +448,51 @@ class NetworkACLAddRuleAction(AbstractAction):
class NetworkACLRemoveRuleAction(AbstractAction):
"""Action which removes a rule from a router's ACL."""
def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
"""Init method for NetworkACLRemoveRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param target_router_uuid: UUID of the router from which the ACL rule should be removed.
:type target_router_uuid: str
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
self.target_router_uuid: str = target_router_uuid
def form_request(self, position: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
class NetworkNICAbstractAction(AbstractAction):
"""
Abstract base class for NIC actions.
Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base
class.
"""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkNICAbstractAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param num_nodes: Number of nodes in the simulation.
:type num_nodes: int
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str
def form_request(self, node_id: int, nic_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id)
nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
if node_uuid is None or nic_uuid is None:
@@ -385,45 +508,24 @@ class NetworkNICAbstractAction(AbstractAction):
class NetworkNICEnableAction(NetworkNICAbstractAction):
"""Action which enables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "enable"
class NetworkNICDisableAction(NetworkNICAbstractAction):
"""Action which disables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
# class NetworkNICDisableAction(AbstractAction):
# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
# super().__init__(manager=manager)
# self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
# def form_request(self, node_id: int, nic_id: int) -> List[str]:
# return [
# "network",
# "node",
# self.manager.get_node_uuid_by_idx(node_idx=node_id),
# "nic",
# self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
# "disable",
# ]
class ActionManager:
# let the action manager handle the conversion of action spaces into a single discrete integer space.
#
"""Class which manages the action space for an agent."""
# when action space is created, it will take subspaces and generate an action map by enumerating all possibilities,
# BUT, the action map can be provided in the config, in which case it will use that.
# action map is basically just a mapping between integer and CAOS action (incl. parameter values)
# for example the action map can be:
# 0: DONOTHING
# 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0
# 2: ......
__act_class_identifiers: Dict[str, type] = {
"DONOTHING": DoNothingAction,
"NODE_SERVICE_SCAN": NodeServiceScanAction,
@@ -453,6 +555,7 @@ class ActionManager:
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""
def __init__(
self,
@@ -469,6 +572,33 @@ class ActionManager:
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
"""Init method for ActionManager.
:param session: Reference to the session to which the agent belongs.
:type session: PrimaiteSession
:param actions: List of action types which should be made available to the agent.
:type actions: List[str]
:param node_uuids: List of node UUIDs that this agent can act on.
:type node_uuids: List[str]
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
:type max_folders_per_node: int
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
:type max_files_per_folder: int
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
:type max_services_per_node: int
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
:type max_nics_per_node: int
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
:type max_acl_rules: int
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
:type protocols: List[str]
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
:type ports: List[str]
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_address_list: Optional[List[str]]
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.session: "PrimaiteSession" = session
self.sim: Simulation = self.session.simulation
self.node_uuids: List[str] = node_uuids
@@ -578,18 +708,48 @@ class ActionManager:
@property
def space(self) -> spaces.Space:
"""Return the gymnasium action space for this agent."""
return spaces.Discrete(len(self.action_map))
def get_node_uuid_by_idx(self, node_idx):
def get_node_uuid_by_idx(self, node_idx: int) -> str:
"""Get the node UUID corresponding to the given index.
:param node_idx: The index of the node to retrieve.
:type node_idx: int
:return: The node UUID.
:rtype: str
"""
return self.node_uuids[node_idx]
def get_folder_uuid_by_idx(self, node_idx, folder_idx) -> Optional[str]:
def get_folder_uuid_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
"""Get the folder UUID corresponding to the given node and folder indices.
:param node_idx: The index of the node.
:type node_idx: int
:param folder_idx: The index of the folder on the node.
:type folder_idx: int
:return: The UUID of the folder. Or None if the node has fewer folders than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]:
def get_file_uuid_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
"""Get the file UUID corresponding to the given node, folder, and file indices.
:param node_idx: The index of the node.
:type node_idx: int
:param folder_idx: The index of the folder on the node.
:type folder_idx: int
:param file_idx: The index of the file in the folder.
:type file_idx: int
:return: The UUID of the file. Or None if the node has fewer folders than the given index, or the folder has
fewer files than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
@@ -599,22 +759,64 @@ class ActionManager:
file_uuids = list(folder.files.keys())
return file_uuids[file_idx] if len(file_uuids) > file_idx else None
def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]:
def get_service_uuid_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
"""Get the service UUID corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param service_idx: The index of the service on the node.
:type service_idx: int
:return: The UUID of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
"""Get the internet protocol corresponding to the given index.
:param protocol_idx: The index of the protocol to retrieve.
:type protocol_idx: int
:return: The protocol.
:rtype: str
"""
return self.protocols[protocol_idx]
def get_ip_address_by_idx(self, ip_idx: int) -> str:
"""
Get the IP address corresponding to the given index.
:param ip_idx: The index of the IP address to retrieve.
:type ip_idx: int
:return: The IP address.
:rtype: str
"""
return self.ip_address_list[ip_idx]
def get_port_by_idx(self, port_idx: int) -> str:
"""
Get the port corresponding to the given index.
:param port_idx: The index of the port to retrieve.
:type port_idx: int
:return: The port.
:rtype: str
"""
return self.ports[port_idx]
def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str:
"""
Get the NIC UUID corresponding to the given node and NIC indices.
:param node_idx: The index of the node.
:type node_idx: int
:param nic_idx: The index of the NIC on the node.
:type nic_idx: int
:return: The NIC UUID.
:rtype: str
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node_obj = self.sim.network.nodes[node_uuid]
nics = list(node_obj.nics.keys())
@@ -624,6 +826,31 @@ class ActionManager:
@classmethod
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config definition.
The action space config supports the following three sections:
1. ``action_list``
``action_list`` contians a list action components which need to be included in the action space.
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
which will be passed to the action class's __init__ method during initialisation.
2. ``action_map``
Since the agent uses a discrete action space which acts as a flattened version of the component-based
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be "
3. ``options``
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
These options are used to calculate the shape of the action space, and to provide additional information
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
:param session: The Primaite Session to which the agent belongs.
:type session: PrimaiteSession
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
obj = cls(
session=session,
actions=cfg["action_list"],

View File

@@ -1,6 +1,4 @@
# TODO: remove this comment... This is just here to point out that I've named this 'actor' rather than 'agent'
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
# If you disagree, make a comment in the PR review and we can discuss
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypeAlias, Union

View File

@@ -1,8 +1,9 @@
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
@@ -10,14 +11,13 @@ if TYPE_CHECKING:
class AbstractReward:
@abstractmethod
def calculate(self, state: Dict) -> float:
return 0.0
@classmethod
@abstractmethod
def from_config(cls, config:dict, session:"PrimaiteSession") -> "AbstractReward":
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
return cls()
@@ -26,16 +26,26 @@ class DummyReward(AbstractReward):
return 0.0
@classmethod
def from_config(cls, config: dict, session:"PrimaiteSession") -> "DummyReward":
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
return cls()
class DatabaseFileIntegrity(AbstractReward):
def __init__(self, node_uuid:str, folder_name:str, file_name:str) -> None:
self.location_in_state = ["network", "nodes", node_uuid, "file_system", "folders",folder_name, "files", file_name]
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
self.location_in_state = [
"network",
"nodes",
node_uuid,
"file_system",
"folders",
folder_name,
"files",
file_name,
]
def calculate(self, state: Dict) -> float:
database_file_state = access_from_nested_dict(state, self.location_in_state)
health_status = database_file_state['health_status']
health_status = database_file_state["health_status"]
if health_status == "corrupted":
return -1
elif health_status == "good":
@@ -48,29 +58,39 @@ class DatabaseFileIntegrity(AbstractReward):
node_ref = config.get("node_ref")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not (node_ref):
_LOGGER.error(f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified")
return DummyReward() #TODO: better error handling
if not node_ref:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not folder_name:
_LOGGER.error(f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified")
return DummyReward() # TODO: better error handling
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not file_name:
_LOGGER.error(f"{cls.__name__} could not be initialised from config because file_name parameter was not specified")
return DummyReward() # TODO: better error handling
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = session.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation")
return DummyReward() # TODO: better error handling
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation"
)
return DummyReward() # TODO: better error handling
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
return cls(node_uuid = node_uuid, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
def __init__(self, node_uuid:str, service_uuid:str) -> None:
self.location_in_state = ['network','nodes', node_uuid, 'services', service_uuid]
def __init__(self, node_uuid: str, service_uuid: str) -> None:
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
def calculate(self, state: Dict) -> float:
web_service_state = access_from_nested_dict(state, self.location_in_state)
most_recent_return_code = web_service_state['most_recent_return_code']
most_recent_return_code = web_service_state["most_recent_return_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1
elif most_recent_return_code == 404:
@@ -85,13 +105,13 @@ class WebServer404Penalty(AbstractReward):
if not (node_ref and service_ref):
msg = f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not found in reward config."
_LOGGER.warn(msg)
return DummyReward() #TODO: should we error out with incorrect inputs? Probably!
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = session.ref_map_nodes[node_ref]
service_uuid = session.ref_map_services[service_ref].uuid
if not (node_uuid and service_uuid):
msg = f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not found in the simulator."
_LOGGER.warn(msg)
return DummyReward() # TODO: consider erroring here as well
return DummyReward() # TODO: consider erroring here as well
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
@@ -101,13 +121,13 @@ class RewardFunction:
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
}
}
def __init__(self):
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
def regsiter_component(self, component:AbstractReward, weight:float=1.0) -> None:
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
self.reward_components.append((component, weight))
def calculate(self, state: Dict) -> float:
@@ -124,8 +144,8 @@ class RewardFunction:
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg.get("weight",1.0)
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.__rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get('options',{}), session=session)
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
new.regsiter_component(component=rew_instance, weight=weight)
return new

View File

@@ -1,39 +1,17 @@
# What do? Be an entry point for using PrimAITE
# 1. parse monoconfig
# 2. craete simulation
# 3. create actors and configure their actions/observations/rewards/ anything else
# 4. Create connection with ARCD GATE
# 5. idk
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple
from arcd_gate.client.gate_client import ActType, GATEClient
from gymnasium import spaces
from gymnasium.spaces.utils import flatten, flatten_space, unflatten
from gymnasium.core import ObsType, ActType
import numpy as np
from gymnasium.core import ActType, ObsType
from gymnasium.spaces.utils import flatten, flatten_space
from pydantic import BaseModel
from primaite import getLogger
from primaite.game.agent.GATE_agents import GATERLAgent
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.observations import (
AclObservation,
FileObservation,
FolderObservation,
ICSObservation,
LinkObservation,
NicObservation,
NodeObservation,
NullObservation,
ObservationSpace,
ServiceObservation,
UC2BlueObservation,
UC2GreenObservation,
UC2RedObservation,
)
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
@@ -50,53 +28,75 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from arcd_gate.client.gate_client import GATEClient, ActType
from numpy import ndarray
_LOGGER = getLogger(__name__)
class PrimaiteGATEClient(GATEClient):
def __init__(self, parent_session:"PrimaiteSession", service_port: int = 50000):
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
"""Create a new GATE client for PrimAITE.
:param parent_session: The parent session object.
:type parent_session: PrimaiteSession
:param service_port: The port on which the GATE service is running.
:type service_port: int, optional"""
super().__init__(service_port=service_port)
self.parent_session:"PrimaiteSession" = parent_session
self.parent_session: "PrimaiteSession" = parent_session
@property
def rl_framework(self) -> str:
"""The reinforcement learning framework to use."""
return self.parent_session.training_options.rl_framework
@property
def rl_algorithm(self) -> str:
"""The reinforcement learning algorithm to use."""
return self.parent_session.training_options.rl_algorithm
@property
def seed(self) -> int | None:
"""The seed to use for the environment's random number generator."""
return self.parent_session.training_options.seed
@property
def n_learn_episodes(self) -> int:
"""The number of episodes in each learning run."""
return self.parent_session.training_options.n_learn_episodes
@property
def n_learn_steps(self) -> int:
"""The number of steps in each learning episode."""
return self.parent_session.training_options.n_learn_steps
@property
def n_eval_episodes(self) -> int:
"""The number of episodes in each evaluation run."""
return self.parent_session.training_options.n_eval_episodes
@property
def n_eval_steps(self) -> int:
"""The number of steps in each evaluation episode."""
return self.parent_session.training_options.n_eval_steps
@property
def action_space(self) -> spaces.Space:
"""The gym action space of the agent."""
return self.parent_session.rl_agent.action_space.space
@property
def observation_space(self) -> spaces.Space:
"""The gymnasium observation space of the agent."""
return flatten_space(self.parent_session.rl_agent.observation_space.space)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
"""Take a step in the environment.
This method is called by GATE to advance the simulation by one timestep.
:param action: The agent's action.
:type action: ActType
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
:rtype: Tuple[ObsType, float, bool, bool, Dict]
"""
self.parent_session.rl_agent.most_recent_action = action
self.parent_session.step()
state = self.parent_session.simulation.describe_state()
@@ -108,8 +108,19 @@ class PrimaiteGATEClient(GATEClient):
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
"""Reset the environment.
This method is called when the environment is initialized and at the end of each episode.
:param seed: The seed to use for the environment's random number generator.
:type seed: int, optional
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
compatibility with GATE.
:type options: dict[str, Any], optional
:return: The initial observation and an empty info dictionary.
:rtype: Tuple[ObsType, Dict]
"""
self.parent_session.reset()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
@@ -117,44 +128,78 @@ class PrimaiteGATEClient(GATEClient):
return obs, {}
def close(self):
"""Close the session, this will stop the gate client and close the simulation."""
self.parent_session.close()
class PrimaiteSessionOptions(BaseModel):
"""Global options which are applicable to all of the agents in the game.
Currently this is used to restrict which ports and protocols exist in the world of the simulation."""
ports: List[str]
protocols: List[str]
class TrainingOptions(BaseModel):
rl_framework:str
rl_algorithm:str
seed:Optional[int]
n_learn_episodes:int
n_learn_steps:int
n_eval_episodes:int
n_eval_steps:int
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
rl_framework: str
rl_algorithm: str
seed: Optional[int]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int
n_eval_steps: int
class PrimaiteSession:
"""
The main entrypoint for PrimAITE sessions, this coordinates a simulation, agents, and connections to ARCD GATE.
"""
def __init__(self):
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agent: AbstractAgent
# which of the agents should be used for sending RL data to GATE client?
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
def start_session(self, opts="TODO..."):
"""Commence the session, this gives the gate client control over the simulation/agent loop."""
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
self.gate_client.start()
def step(self):
"""
Perform one step of the simulation/agent loop.
This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
agent. The steps are as follows:
1. The simulation state is updated.
2. The simulation state is sent to each agent.
3. Each agent converts the state to an observation and calculates a reward.
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
@@ -192,19 +237,36 @@ class PrimaiteSession:
self.step_counter += 1
def reset(self):
pass
"""Reset the session, this will reset the simulation."""
return NotImplemented
def close(self):
pass
"""Close the session, this will stop the gate client and close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent. Used by GATE.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
# TODO: create documentation page and add links to it here.
:param cfg: The config dictionary.
:type cfg: dict
:return: A PrimaiteSession object.
:rtype: PrimaiteSession
"""
sess = cls()
sess.options = PrimaiteSessionOptions(
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg['training_config'])
sess.training_options = TrainingOptions(**cfg["training_config"])
sim = sess.simulation
net = sim.network
@@ -295,7 +357,11 @@ class PrimaiteSession:
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[node_ref] = new_node.uuid # TODO: fix incosistency with service and link. Node gets added by uuid, but service gets reference to object
sess.ref_map_nodes[
node_ref
] = (
new_node.uuid
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
# 2. create links between nodes
for link_cfg in links_cfg:
@@ -314,12 +380,10 @@ class PrimaiteSession:
# 3. create agents
game_cfg = cfg["game_config"]
ports_cfg = game_cfg["ports"]
protocols_cfg = game_cfg["protocols"]
agents_cfg = game_cfg["agents"]
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"]
agent_ref = agent_cfg["ref"] # noqa: F841
agent_type = agent_cfg["type"]
action_space_cfg = agent_cfg["action_space"]
observation_space_cfg = agent_cfg["observation_space"]