Add documentation
This commit is contained in:
@@ -2,10 +2,10 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 2
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 2
|
||||
n_eval_steps: 128
|
||||
n_learn_episodes: 1
|
||||
n_learn_steps: 8
|
||||
n_eval_episodes: 0
|
||||
n_eval_steps: 8
|
||||
|
||||
|
||||
game_config:
|
||||
@@ -451,6 +451,8 @@ game_config:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""PrimAITE Game Layer."""
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
"""
|
||||
This module contains the ActionManager class which belongs to the Agent class.
|
||||
|
||||
An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of
|
||||
AbstractAction. The ActionManager is responsible for:
|
||||
1. Creating the action space from a list of action types.
|
||||
2. Converting an integer action choice into a specific action and parameter choice.
|
||||
3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This
|
||||
ensures that requests conform to the simulator's request format.
|
||||
"""
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
@@ -13,22 +23,9 @@ if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class ExecutionDefiniton(ABC):
|
||||
"""
|
||||
Converter from actions to simulator requests.
|
||||
|
||||
Allows adding extra data/context that defines in more detail what an action means.
|
||||
"""
|
||||
|
||||
"""
|
||||
Examples:
|
||||
('node', 'service', 'scan', 2, 0) means scan the first service on node index 2
|
||||
-> ['network', 'nodes', <node-idx-2-uuid>, 'services', <svc-idx-0-uuid>, 'scan'w]
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AbstractAction(ABC):
|
||||
"""Base class for actions."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
"""
|
||||
@@ -46,6 +43,8 @@ class AbstractAction(ABC):
|
||||
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
|
||||
align with the keyword args of the form_request method."""
|
||||
self.manager: ActionManager = manager
|
||||
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
|
||||
objects."""
|
||||
|
||||
@abstractmethod
|
||||
def form_request(self) -> List[str]:
|
||||
@@ -54,6 +53,8 @@ class AbstractAction(ABC):
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction):
|
||||
"""Action which does nothing. This is here to allow agents to be idle if they choose to."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.name = "DONOTHING"
|
||||
@@ -65,6 +66,7 @@ class DoNothingAction(AbstractAction):
|
||||
# with one option. This just aids the Action Manager to enumerate all possibilities.
|
||||
|
||||
def form_request(self, **kwargs) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
@@ -77,12 +79,13 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes, num_services, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, service_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
|
||||
if node_uuid is None or service_uuid is None:
|
||||
@@ -91,54 +94,77 @@ class NodeServiceAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
|
||||
"""Action which scans a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction):
|
||||
"""Action which stops a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction):
|
||||
"""Action which starts a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction):
|
||||
"""Action which pauses a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction):
|
||||
"""Action which resumes a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction):
|
||||
"""Action which restarts a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction):
|
||||
"""Action which disables a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction):
|
||||
"""Action which enables a service."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
|
||||
Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
@@ -146,6 +172,7 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
if node_uuid is None or folder_uuid is None:
|
||||
@@ -154,30 +181,44 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
"""Action which scans a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
|
||||
"""Action which checks the hash of a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction):
|
||||
"""Action which repairs a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction):
|
||||
"""Action which restores a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
|
||||
self.verb: str = "restore"
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction):
|
||||
"""Abstract base class for file actions.
|
||||
|
||||
Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit
|
||||
from this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
@@ -185,6 +226,7 @@ class NodeFileAbstractAction(AbstractAction):
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
|
||||
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
|
||||
@@ -194,42 +236,60 @@ class NodeFileAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction):
|
||||
"""Action which scans a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction):
|
||||
"""Action which checks the hash of a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "checkhash"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction):
|
||||
"""Action which deletes a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "delete"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction):
|
||||
"""Action which repairs a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "repair"
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction):
|
||||
"""Action which restores a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "restore"
|
||||
|
||||
|
||||
class NodeFileCorruptAction(NodeFileAbstractAction):
|
||||
"""Action which corrupts a file."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
|
||||
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
|
||||
self.verb = "corrupt"
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for node actions.
|
||||
|
||||
Any action which applies to a node and uses node_id as its only parameter can inherit from this base class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
@@ -237,35 +297,46 @@ class NodeAbstractAction(AbstractAction):
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
|
||||
return ["network", "node", node_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction):
|
||||
"""Action which scans a node's OS."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction):
|
||||
"""Action which shuts down a node."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction):
|
||||
"""Action which starts up a node."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "startup"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction):
|
||||
"""Action which resets a node."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes)
|
||||
self.verb = "reset"
|
||||
|
||||
|
||||
class NetworkACLAddRuleAction(AbstractAction):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: "ActionManager",
|
||||
@@ -276,6 +347,21 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
num_protocols: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for NetworkACLAddRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param target_router_uuid: UUID of the router to which the ACL rule should be added.
|
||||
:type target_router_uuid: str
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
:param num_ips: Number of IP addresses in the simulation.
|
||||
:type num_ips: int
|
||||
:param num_ports: Number of ports in the simulation.
|
||||
:type num_ports: int
|
||||
:param num_protocols: Number of protocols in the simulation.
|
||||
:type num_protocols: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 3
|
||||
self.shape: Dict[str, int] = {
|
||||
@@ -290,8 +376,16 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(
|
||||
self, position, permission, source_ip_id, dest_ip_id, source_port_id, dest_port_id, protocol_id
|
||||
self,
|
||||
position: int,
|
||||
permission: int,
|
||||
source_ip_id: int,
|
||||
dest_ip_id: int,
|
||||
source_port_id: int,
|
||||
dest_port_id: int,
|
||||
protocol_id: int,
|
||||
) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if permission == 0:
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
@@ -354,22 +448,51 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
|
||||
|
||||
class NetworkACLRemoveRuleAction(AbstractAction):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
|
||||
"""Init method for NetworkACLRemoveRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param target_router_uuid: UUID of the router from which the ACL rule should be removed.
|
||||
:type target_router_uuid: str
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
self.target_router_uuid: str = target_router_uuid
|
||||
|
||||
def form_request(self, position: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
|
||||
|
||||
|
||||
class NetworkNICAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for NIC actions.
|
||||
|
||||
Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base
|
||||
class.
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkNICAbstractAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param num_nodes: Number of nodes in the simulation.
|
||||
:type num_nodes: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
self.verb: str
|
||||
|
||||
def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id)
|
||||
nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
|
||||
if node_uuid is None or nic_uuid is None:
|
||||
@@ -385,45 +508,24 @@ class NetworkNICAbstractAction(AbstractAction):
|
||||
|
||||
|
||||
class NetworkNICEnableAction(NetworkNICAbstractAction):
|
||||
"""Action which enables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "enable"
|
||||
|
||||
|
||||
class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb = "disable"
|
||||
|
||||
|
||||
# class NetworkNICDisableAction(AbstractAction):
|
||||
# def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
# super().__init__(manager=manager)
|
||||
# self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
|
||||
|
||||
# def form_request(self, node_id: int, nic_id: int) -> List[str]:
|
||||
# return [
|
||||
# "network",
|
||||
# "node",
|
||||
# self.manager.get_node_uuid_by_idx(node_idx=node_id),
|
||||
# "nic",
|
||||
# self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id),
|
||||
# "disable",
|
||||
# ]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
# let the action manager handle the conversion of action spaces into a single discrete integer space.
|
||||
#
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
# when action space is created, it will take subspaces and generate an action map by enumerating all possibilities,
|
||||
# BUT, the action map can be provided in the config, in which case it will use that.
|
||||
|
||||
# action map is basically just a mapping between integer and CAOS action (incl. parameter values)
|
||||
# for example the action map can be:
|
||||
# 0: DONOTHING
|
||||
# 1: NODE, FILE, SCAN, NODEID=2, FOLDERID=1, FILEID=0
|
||||
# 2: ......
|
||||
__act_class_identifiers: Dict[str, type] = {
|
||||
"DONOTHING": DoNothingAction,
|
||||
"NODE_SERVICE_SCAN": NodeServiceScanAction,
|
||||
@@ -453,6 +555,7 @@ class ActionManager:
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -469,6 +572,33 @@ class ActionManager:
|
||||
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param session: Reference to the session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param actions: List of action types which should be made available to the agent.
|
||||
:type actions: List[str]
|
||||
:param node_uuids: List of node UUIDs that this agent can act on.
|
||||
:type node_uuids: List[str]
|
||||
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
|
||||
:type max_folders_per_node: int
|
||||
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
|
||||
:type max_files_per_folder: int
|
||||
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
|
||||
:type max_services_per_node: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
|
||||
:type max_nics_per_node: int
|
||||
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
|
||||
:type max_acl_rules: int
|
||||
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
|
||||
:type protocols: List[str]
|
||||
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
|
||||
:type ports: List[str]
|
||||
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
|
||||
:type ip_address_list: Optional[List[str]]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.sim: Simulation = self.session.simulation
|
||||
self.node_uuids: List[str] = node_uuids
|
||||
@@ -578,18 +708,48 @@ class ActionManager:
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
def get_node_uuid_by_idx(self, node_idx):
|
||||
def get_node_uuid_by_idx(self, node_idx: int) -> str:
|
||||
"""Get the node UUID corresponding to the given index.
|
||||
|
||||
:param node_idx: The index of the node to retrieve.
|
||||
:type node_idx: int
|
||||
:return: The node UUID.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.node_uuids[node_idx]
|
||||
|
||||
def get_folder_uuid_by_idx(self, node_idx, folder_idx) -> Optional[str]:
|
||||
def get_folder_uuid_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
|
||||
"""Get the folder UUID corresponding to the given node and folder indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:return: The UUID of the folder. Or None if the node has fewer folders than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
|
||||
|
||||
def get_file_uuid_by_idx(self, node_idx, folder_idx, file_idx) -> Optional[str]:
|
||||
def get_file_uuid_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
|
||||
"""Get the file UUID corresponding to the given node, folder, and file indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param folder_idx: The index of the folder on the node.
|
||||
:type folder_idx: int
|
||||
:param file_idx: The index of the file in the folder.
|
||||
:type file_idx: int
|
||||
:return: The UUID of the file. Or None if the node has fewer folders than the given index, or the folder has
|
||||
fewer files than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
folder_uuids = list(node.file_system.folders.keys())
|
||||
@@ -599,22 +759,64 @@ class ActionManager:
|
||||
file_uuids = list(folder.files.keys())
|
||||
return file_uuids[file_idx] if len(file_uuids) > file_idx else None
|
||||
|
||||
def get_service_uuid_by_idx(self, node_idx, service_idx) -> Optional[str]:
|
||||
def get_service_uuid_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
|
||||
"""Get the service UUID corresponding to the given node and service indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param service_idx: The index of the service on the node.
|
||||
:type service_idx: int
|
||||
:return: The UUID of the service. Or None if the node has fewer services than the given index.
|
||||
:rtype: Optional[str]
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node = self.sim.network.nodes[node_uuid]
|
||||
service_uuids = list(node.services.keys())
|
||||
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
|
||||
|
||||
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
|
||||
"""Get the internet protocol corresponding to the given index.
|
||||
|
||||
:param protocol_idx: The index of the protocol to retrieve.
|
||||
:type protocol_idx: int
|
||||
:return: The protocol.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.protocols[protocol_idx]
|
||||
|
||||
def get_ip_address_by_idx(self, ip_idx: int) -> str:
|
||||
"""
|
||||
Get the IP address corresponding to the given index.
|
||||
|
||||
:param ip_idx: The index of the IP address to retrieve.
|
||||
:type ip_idx: int
|
||||
:return: The IP address.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.ip_address_list[ip_idx]
|
||||
|
||||
def get_port_by_idx(self, port_idx: int) -> str:
|
||||
"""
|
||||
Get the port corresponding to the given index.
|
||||
|
||||
:param port_idx: The index of the port to retrieve.
|
||||
:type port_idx: int
|
||||
:return: The port.
|
||||
:rtype: str
|
||||
"""
|
||||
return self.ports[port_idx]
|
||||
|
||||
def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str:
|
||||
"""
|
||||
Get the NIC UUID corresponding to the given node and NIC indices.
|
||||
|
||||
:param node_idx: The index of the node.
|
||||
:type node_idx: int
|
||||
:param nic_idx: The index of the NIC on the node.
|
||||
:type nic_idx: int
|
||||
:return: The NIC UUID.
|
||||
:rtype: str
|
||||
"""
|
||||
node_uuid = self.get_node_uuid_by_idx(node_idx)
|
||||
node_obj = self.sim.network.nodes[node_uuid]
|
||||
nics = list(node_obj.nics.keys())
|
||||
@@ -624,6 +826,31 @@ class ActionManager:
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
The action space config supports the following three sections:
|
||||
1. ``action_list``
|
||||
``action_list`` contians a list action components which need to be included in the action space.
|
||||
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
|
||||
which will be passed to the action class's __init__ method during initialisation.
|
||||
2. ``action_map``
|
||||
Since the agent uses a discrete action space which acts as a flattened version of the component-based
|
||||
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
|
||||
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
|
||||
correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be "
|
||||
3. ``options``
|
||||
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param session: The Primaite Session to which the agent belongs.
|
||||
:type session: PrimaiteSession
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
obj = cls(
|
||||
session=session,
|
||||
actions=cfg["action_list"],
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# TODO: remove this comment... This is just here to point out that I've named this 'actor' rather than 'agent'
|
||||
# That's because I want to point out that this is disctinct from 'agent' in the reinforcement learning sense of the word
|
||||
# If you disagree, make a comment in the PR review and we can discuss
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -10,14 +11,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict) -> float:
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config:dict, session:"PrimaiteSession") -> "AbstractReward":
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
|
||||
return cls()
|
||||
|
||||
|
||||
@@ -26,16 +26,26 @@ class DummyReward(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict, session:"PrimaiteSession") -> "DummyReward":
|
||||
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
|
||||
return cls()
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward):
|
||||
def __init__(self, node_uuid:str, folder_name:str, file_name:str) -> None:
|
||||
self.location_in_state = ["network", "nodes", node_uuid, "file_system", "folders",folder_name, "files", file_name]
|
||||
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
node_uuid,
|
||||
"file_system",
|
||||
"folders",
|
||||
folder_name,
|
||||
"files",
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
health_status = database_file_state['health_status']
|
||||
health_status = database_file_state["health_status"]
|
||||
if health_status == "corrupted":
|
||||
return -1
|
||||
elif health_status == "good":
|
||||
@@ -48,29 +58,39 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
node_ref = config.get("node_ref")
|
||||
folder_name = config.get("folder_name")
|
||||
file_name = config.get("file_name")
|
||||
if not (node_ref):
|
||||
_LOGGER.error(f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified")
|
||||
return DummyReward() #TODO: better error handling
|
||||
if not node_ref:
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not folder_name:
|
||||
_LOGGER.error(f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified")
|
||||
return DummyReward() # TODO: better error handling
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
if not file_name:
|
||||
_LOGGER.error(f"{cls.__name__} could not be initialised from config because file_name parameter was not specified")
|
||||
return DummyReward() # TODO: better error handling
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
if not node_uuid:
|
||||
_LOGGER.error(f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation")
|
||||
return DummyReward() # TODO: better error handling
|
||||
_LOGGER.error(
|
||||
f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation"
|
||||
)
|
||||
return DummyReward() # TODO: better error handling
|
||||
|
||||
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
return cls(node_uuid = node_uuid, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
def __init__(self, node_uuid:str, service_uuid:str) -> None:
|
||||
self.location_in_state = ['network','nodes', node_uuid, 'services', service_uuid]
|
||||
def __init__(self, node_uuid: str, service_uuid: str) -> None:
|
||||
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
most_recent_return_code = web_service_state['most_recent_return_code']
|
||||
most_recent_return_code = web_service_state["most_recent_return_code"]
|
||||
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
|
||||
if most_recent_return_code == 200:
|
||||
return 1
|
||||
elif most_recent_return_code == 404:
|
||||
@@ -85,13 +105,13 @@ class WebServer404Penalty(AbstractReward):
|
||||
if not (node_ref and service_ref):
|
||||
msg = f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not found in reward config."
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() #TODO: should we error out with incorrect inputs? Probably!
|
||||
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = session.ref_map_nodes[node_ref]
|
||||
service_uuid = session.ref_map_services[service_ref].uuid
|
||||
if not (node_uuid and service_uuid):
|
||||
msg = f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not found in the simulator."
|
||||
_LOGGER.warn(msg)
|
||||
return DummyReward() # TODO: consider erroring here as well
|
||||
return DummyReward() # TODO: consider erroring here as well
|
||||
|
||||
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
|
||||
|
||||
@@ -101,13 +121,13 @@ class RewardFunction:
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
|
||||
def regsiter_component(self, component:AbstractReward, weight:float=1.0) -> None:
|
||||
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
@@ -124,8 +144,8 @@ class RewardFunction:
|
||||
|
||||
for rew_component_cfg in config["reward_components"]:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
weight = rew_component_cfg.get("weight",1.0)
|
||||
weight = rew_component_cfg.get("weight", 1.0)
|
||||
rew_class = cls.__rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get('options',{}), session=session)
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
|
||||
new.regsiter_component(component=rew_instance, weight=weight)
|
||||
return new
|
||||
|
||||
@@ -1,39 +1,17 @@
|
||||
# What do? Be an entry point for using PrimAITE
|
||||
# 1. parse monoconfig
|
||||
# 2. craete simulation
|
||||
# 3. create actors and configure their actions/observations/rewards/ anything else
|
||||
# 4. Create connection with ARCD GATE
|
||||
# 5. idk
|
||||
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from arcd_gate.client.gate_client import ActType, GATEClient
|
||||
from gymnasium import spaces
|
||||
from gymnasium.spaces.utils import flatten, flatten_space, unflatten
|
||||
from gymnasium.core import ObsType, ActType
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces.utils import flatten, flatten_space
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.GATE_agents import GATERLAgent
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import (
|
||||
AclObservation,
|
||||
FileObservation,
|
||||
FolderObservation,
|
||||
ICSObservation,
|
||||
LinkObservation,
|
||||
NicObservation,
|
||||
NodeObservation,
|
||||
NullObservation,
|
||||
ObservationSpace,
|
||||
ServiceObservation,
|
||||
UC2BlueObservation,
|
||||
UC2GreenObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
@@ -50,53 +28,75 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.service import Service
|
||||
|
||||
from arcd_gate.client.gate_client import GATEClient, ActType
|
||||
from numpy import ndarray
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
|
||||
def __init__(self, parent_session:"PrimaiteSession", service_port: int = 50000):
|
||||
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
|
||||
"""Create a new GATE client for PrimAITE.
|
||||
|
||||
:param parent_session: The parent session object.
|
||||
:type parent_session: PrimaiteSession
|
||||
:param service_port: The port on which the GATE service is running.
|
||||
:type service_port: int, optional"""
|
||||
super().__init__(service_port=service_port)
|
||||
self.parent_session:"PrimaiteSession" = parent_session
|
||||
self.parent_session: "PrimaiteSession" = parent_session
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
"""The reinforcement learning framework to use."""
|
||||
return self.parent_session.training_options.rl_framework
|
||||
|
||||
@property
|
||||
def rl_algorithm(self) -> str:
|
||||
"""The reinforcement learning algorithm to use."""
|
||||
return self.parent_session.training_options.rl_algorithm
|
||||
|
||||
@property
|
||||
def seed(self) -> int | None:
|
||||
"""The seed to use for the environment's random number generator."""
|
||||
return self.parent_session.training_options.seed
|
||||
|
||||
@property
|
||||
def n_learn_episodes(self) -> int:
|
||||
"""The number of episodes in each learning run."""
|
||||
return self.parent_session.training_options.n_learn_episodes
|
||||
|
||||
@property
|
||||
def n_learn_steps(self) -> int:
|
||||
"""The number of steps in each learning episode."""
|
||||
return self.parent_session.training_options.n_learn_steps
|
||||
|
||||
@property
|
||||
def n_eval_episodes(self) -> int:
|
||||
"""The number of episodes in each evaluation run."""
|
||||
return self.parent_session.training_options.n_eval_episodes
|
||||
|
||||
@property
|
||||
def n_eval_steps(self) -> int:
|
||||
"""The number of steps in each evaluation episode."""
|
||||
return self.parent_session.training_options.n_eval_steps
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
"""The gym action space of the agent."""
|
||||
return self.parent_session.rl_agent.action_space.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
"""The gymnasium observation space of the agent."""
|
||||
return flatten_space(self.parent_session.rl_agent.observation_space.space)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
|
||||
"""Take a step in the environment.
|
||||
|
||||
This method is called by GATE to advance the simulation by one timestep.
|
||||
|
||||
:param action: The agent's action.
|
||||
:type action: ActType
|
||||
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
|
||||
:rtype: Tuple[ObsType, float, bool, bool, Dict]
|
||||
"""
|
||||
self.parent_session.rl_agent.most_recent_action = action
|
||||
self.parent_session.step()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
@@ -108,8 +108,19 @@ class PrimaiteGATEClient(GATEClient):
|
||||
info = {}
|
||||
return obs, rew, term, trunc, info
|
||||
|
||||
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment.
|
||||
|
||||
This method is called when the environment is initialized and at the end of each episode.
|
||||
|
||||
:param seed: The seed to use for the environment's random number generator.
|
||||
:type seed: int, optional
|
||||
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
|
||||
compatibility with GATE.
|
||||
:type options: dict[str, Any], optional
|
||||
:return: The initial observation and an empty info dictionary.
|
||||
:rtype: Tuple[ObsType, Dict]
|
||||
"""
|
||||
self.parent_session.reset()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
@@ -117,44 +128,78 @@ class PrimaiteGATEClient(GATEClient):
|
||||
return obs, {}
|
||||
|
||||
def close(self):
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
self.parent_session.close()
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
"""Global options which are applicable to all of the agents in the game.
|
||||
|
||||
Currently this is used to restrict which ports and protocols exist in the world of the simulation."""
|
||||
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
rl_framework:str
|
||||
rl_algorithm:str
|
||||
seed:Optional[int]
|
||||
n_learn_episodes:int
|
||||
n_learn_steps:int
|
||||
n_eval_episodes:int
|
||||
n_eval_steps:int
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
rl_framework: str
|
||||
rl_algorithm: str
|
||||
seed: Optional[int]
|
||||
n_learn_episodes: int
|
||||
n_learn_steps: int
|
||||
n_eval_episodes: int
|
||||
n_eval_steps: int
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""
|
||||
The main entrypoint for PrimAITE sessions, this coordinates a simulation, agents, and connections to ARCD GATE.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
self.rl_agent: AbstractAgent
|
||||
# which of the agents should be used for sending RL data to GATE client?
|
||||
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.options: PrimaiteSessionOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
self.ref_map_services: Dict[str, Service] = {}
|
||||
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
|
||||
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
|
||||
|
||||
def start_session(self, opts="TODO..."):
|
||||
"""Commence the session, this gives the gate client control over the simulation/agent loop."""
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
self.gate_client.start()
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
|
||||
This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
|
||||
agent. The steps are as follows:
|
||||
1. The simulation state is updated.
|
||||
2. The simulation state is sent to each agent.
|
||||
3. Each agent converts the state to an observation and calculates a reward.
|
||||
4. Each agent chooses an action based on the observation.
|
||||
5. Each agent converts the action to a request.
|
||||
6. The simulation applies the requests.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
@@ -192,19 +237,36 @@ class PrimaiteSession:
|
||||
self.step_counter += 1
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent. Used by GATE.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
The specification for each of the three major areas is described in a separate documentation page.
|
||||
# TODO: create documentation page and add links to it here.
|
||||
|
||||
:param cfg: The config dictionary.
|
||||
:type cfg: dict
|
||||
:return: A PrimaiteSession object.
|
||||
:rtype: PrimaiteSession
|
||||
"""
|
||||
sess = cls()
|
||||
sess.options = PrimaiteSessionOptions(
|
||||
ports=cfg["game_config"]["ports"],
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg['training_config'])
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
@@ -295,7 +357,11 @@ class PrimaiteSession:
|
||||
|
||||
net.add_node(new_node)
|
||||
new_node.power_on()
|
||||
sess.ref_map_nodes[node_ref] = new_node.uuid # TODO: fix incosistency with service and link. Node gets added by uuid, but service gets reference to object
|
||||
sess.ref_map_nodes[
|
||||
node_ref
|
||||
] = (
|
||||
new_node.uuid
|
||||
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
@@ -314,12 +380,10 @@ class PrimaiteSession:
|
||||
|
||||
# 3. create agents
|
||||
game_cfg = cfg["game_config"]
|
||||
ports_cfg = game_cfg["ports"]
|
||||
protocols_cfg = game_cfg["protocols"]
|
||||
agents_cfg = game_cfg["agents"]
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_ref = agent_cfg["ref"]
|
||||
agent_ref = agent_cfg["ref"] # noqa: F841
|
||||
agent_type = agent_cfg["type"]
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
|
||||
Reference in New Issue
Block a user