diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index cf2b618f..0194bd72 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -103,7 +103,7 @@ Similar to action space, this is defined as a list of components from the :py:mo ``reward_components`` ^^^^^^^^^^^^^^^^^^^^^ - +TODO: update description A list of reward types from :py:mod:`primaite.game.agent.rewards.RewardFunction.rew_class_identifiers` e.g. diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst index 0b2067d8..47ff6832 100644 --- a/docs/source/configuration/simulation.rst +++ b/docs/source/configuration/simulation.rst @@ -6,7 +6,7 @@ ``simulation`` ============== In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. - +# TODO: ref field is no longer real At the top level of the network are ``nodes``, ``links`` and ``airspace``. e.g. diff --git a/docs/source/how_to_guides/extensible_actions.rst b/docs/source/how_to_guides/extensible_actions.rst index 93a6cf21..4deede53 100644 --- a/docs/source/how_to_guides/extensible_actions.rst +++ b/docs/source/how_to_guides/extensible_actions.rst @@ -20,7 +20,7 @@ Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and cont #. ConfigSchema class -#. Unique Identifier +#. Unique discriminator #. `form_request` method. @@ -31,14 +31,14 @@ ConfigSchema The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file. -Unique Identifier +Unique discriminator ################# -When declaring a custom class, it must have a unique identifier string, that allows PrimAITE to generate the correct action when needed. +When declaring a custom class, it must have a unique discriminator string, that allows PrimAITE to generate the correct action when needed. .. code:: Python - class CreateDirectoryAction(AbstractAction, identifier="node_folder_create") + class CreateDirectoryAction(AbstractAction, discriminator="node_folder_create") config: CreateDirectoryAction.ConfigSchema @@ -58,7 +58,7 @@ When declaring a custom class, it must have a unique identifier string, that all config.directory_name, ] -The above action would fail pydantic validation as the identifier "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`. +The above action would fail pydantic validation as the discriminator "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`. form_request method diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst index 256e96ca..d83f83d6 100644 --- a/docs/source/how_to_guides/extensible_agents.rst +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -25,7 +25,7 @@ The core features that should be implemented in any new agent are detailed below .. code-block:: python - class ExampleAgent(AbstractAgent, identifier = "ExampleAgent"): + class ExampleAgent(AbstractAgent, discriminator = "ExampleAgent"): """An example agent for demonstration purposes.""" config: "ExampleAgent.ConfigSchema" = Field(default_factory= lambda: ExampleAgent.ConfigSchema()) @@ -64,9 +64,9 @@ The core features that should be implemented in any new agent are detailed below starting_host: "Server_1" -#. **Identifiers**: +#. **discriminators**: - All agent classes should have an ``identifier`` attribute, a unique kebab-case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent. + All agent classes should have an ``discriminator`` attribute, a unique kebab-case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent. Changes to YAML file ==================== diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst index a01b9d8f..9068f1bb 100644 --- a/docs/source/how_to_guides/extensible_rewards.rst +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -17,7 +17,7 @@ Reward classes are inherited from AbstractReward (a sub-class of Pydantic's Base Within the reward class there is a ConfigSchema class responsible for ensuring the config file data is in the correct format. This also means there is little (if no) requirement for and `__init__` method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`. -Each class requires an identifier string which is used by the ConfigSchema class to verify that it +Each class requires an discriminator string which is used by the ConfigSchema class to verify that it hasn't previously been added to the registry. Inheriting from `BaseModel` removes the need for an `__init__` method but means that object @@ -28,7 +28,7 @@ To add a new reward class follow the example below. Note that the type attribute .. code-block:: Python -class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): +class DatabaseFileIntegrity(AbstractReward, discriminator="DATABASE_FILE_INTEGRITY"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" config: "DatabaseFileIntegrity.ConfigSchema" diff --git a/docs/source/node_sets.rst b/docs/source/node_sets.rst index 3c247478..1ac6a54c 100644 --- a/docs/source/node_sets.rst +++ b/docs/source/node_sets.rst @@ -82,7 +82,7 @@ Here is an example of creating a custom node adder, DataCenterAdder: .. code-block:: python - class DataCenterAdder(NetworkNodeAdder, identifier="data_center"): + class DataCenterAdder(NetworkNodeAdder, discriminator="data_center"): class ConfigSchema(NetworkNodeAdder.ConfigSchema): type: Literal["data_center"] = "data_center" num_servers: int diff --git a/src/primaite/game/agent/actions/abstract.py b/src/primaite/game/agent/actions/abstract.py index 1cda4360..9e55cf09 100644 --- a/src/primaite/game/agent/actions/abstract.py +++ b/src/primaite/game/agent/actions/abstract.py @@ -22,13 +22,20 @@ class AbstractAction(BaseModel, ABC): _registry: ClassVar[Dict[str, Type[AbstractAction]]] = {} - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: + """ + Register an action type. + + :param discriminator: discriminator used to uniquely specify action types. + :type discriminator: str + :raises ValueError: When attempting to create an action with a name that is already in use. + """ super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - if identifier in cls._registry: - raise ValueError(f"Cannot create new action under reserved name {identifier}") - cls._registry[identifier] = cls + if discriminator in cls._registry: + raise ValueError(f"Cannot create new action under reserved name {discriminator}") + cls._registry[discriminator] = cls @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index 7b70d10d..4ef4b506 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -37,7 +37,7 @@ class ACLAddRuleAbstractAction(AbstractAction, ABC): dst_wildcard: Union[IPV4Address, Literal["NONE"]] -class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"): +class ACLRemoveRuleAbstractAction(AbstractAction, discriminator="acl_remove_rule_abstract_action"): """Base abstract class for ACL remove rule actions.""" config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema" @@ -48,7 +48,7 @@ class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_ab position: int -class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"): +class RouterACLAddRuleAction(ACLAddRuleAbstractAction, discriminator="router_acl_add_rule"): """Action which adds a rule to a router's ACL.""" config: "RouterACLAddRuleAction.ConfigSchema" @@ -79,7 +79,7 @@ class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_ad ] -class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"): +class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, discriminator="router_acl_remove_rule"): """Action which removes a rule from a router's ACL.""" config: "RouterACLRemoveRuleAction.ConfigSchema" @@ -95,7 +95,7 @@ class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_ return ["network", "node", config.target_router, "acl", "remove_rule", config.position] -class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"): +class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, discriminator="firewall_acl_add_rule"): """Action which adds a rule to a firewall port's ACL.""" config: "FirewallACLAddRuleAction.ConfigSchema" @@ -130,7 +130,7 @@ class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_ac ] -class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"): +class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, discriminator="firewall_acl_remove_rule"): """Action which removes a rule from a firewall port's ACL.""" config: "FirewallACLRemoveRuleAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/application.py b/src/primaite/game/agent/actions/application.py index f6ce0624..36d2e0b4 100644 --- a/src/primaite/game/agent/actions/application.py +++ b/src/primaite/game/agent/actions/application.py @@ -45,7 +45,7 @@ class NodeApplicationAbstractAction(AbstractAction, ABC): ] -class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"): +class NodeApplicationExecuteAction(NodeApplicationAbstractAction, discriminator="node_application_execute"): """Action which executes an application.""" config: "NodeApplicationExecuteAction.ConfigSchema" @@ -56,7 +56,7 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="no verb: str = "execute" -class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"): +class NodeApplicationScanAction(NodeApplicationAbstractAction, discriminator="node_application_scan"): """Action which scans an application.""" config: "NodeApplicationScanAction.ConfigSchema" @@ -67,7 +67,7 @@ class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_ verb: str = "scan" -class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"): +class NodeApplicationCloseAction(NodeApplicationAbstractAction, discriminator="node_application_close"): """Action which closes an application.""" config: "NodeApplicationCloseAction.ConfigSchema" @@ -78,7 +78,7 @@ class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node verb: str = "close" -class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"): +class NodeApplicationFixAction(NodeApplicationAbstractAction, discriminator="node_application_fix"): """Action which fixes an application.""" config: "NodeApplicationFixAction.ConfigSchema" @@ -89,7 +89,7 @@ class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_a verb: str = "fix" -class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"): +class NodeApplicationInstallAction(NodeApplicationAbstractAction, discriminator="node_application_install"): """Action which installs an application.""" config: "NodeApplicationInstallAction.ConfigSchema" @@ -113,7 +113,7 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="no ] -class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"): +class NodeApplicationRemoveAction(NodeApplicationAbstractAction, discriminator="node_application_remove"): """Action which removes/uninstalls an application.""" config: "NodeApplicationRemoveAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/file.py b/src/primaite/game/agent/actions/file.py index ed666773..2aa3b85c 100644 --- a/src/primaite/game/agent/actions/file.py +++ b/src/primaite/game/agent/actions/file.py @@ -52,7 +52,7 @@ class NodeFileAbstractAction(AbstractAction, ABC): ] -class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"): +class NodeFileCreateAction(NodeFileAbstractAction, discriminator="node_file_create"): """Action which creates a new file in a given folder.""" config: "NodeFileCreateAction.ConfigSchema" @@ -81,7 +81,7 @@ class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create" ] -class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"): +class NodeFileScanAction(NodeFileAbstractAction, discriminator="node_file_scan"): """Action which scans a file.""" config: "NodeFileScanAction.ConfigSchema" @@ -92,7 +92,7 @@ class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"): verb: ClassVar[str] = "scan" -class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"): +class NodeFileDeleteAction(NodeFileAbstractAction, discriminator="node_file_delete"): """Action which deletes a file.""" config: "NodeFileDeleteAction.ConfigSchema" @@ -119,7 +119,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete" ] -class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"): +class NodeFileRestoreAction(NodeFileAbstractAction, discriminator="node_file_restore"): """Action which restores a file.""" config: "NodeFileRestoreAction.ConfigSchema" @@ -130,7 +130,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restor verb: ClassVar[str] = "restore" -class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"): +class NodeFileCorruptAction(NodeFileAbstractAction, discriminator="node_file_corrupt"): """Action which corrupts a file.""" config: "NodeFileCorruptAction.ConfigSchema" @@ -141,7 +141,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrup verb: ClassVar[str] = "corrupt" -class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"): +class NodeFileAccessAction(NodeFileAbstractAction, discriminator="node_file_access"): """Action which increases a file's access count.""" config: "NodeFileAccessAction.ConfigSchema" @@ -167,7 +167,7 @@ class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access" ] -class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"): +class NodeFileCheckhashAction(NodeFileAbstractAction, discriminator="node_file_checkhash"): """Action which checks the hash of a file.""" config: "NodeFileCheckhashAction.ConfigSchema" @@ -178,7 +178,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_chec verb: ClassVar[str] = "checkhash" -class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"): +class NodeFileRepairAction(NodeFileAbstractAction, discriminator="node_file_repair"): """Action which repairs a file.""" config: "NodeFileRepairAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/folder.py b/src/primaite/game/agent/actions/folder.py index 3e1136ac..c0a03398 100644 --- a/src/primaite/game/agent/actions/folder.py +++ b/src/primaite/game/agent/actions/folder.py @@ -47,7 +47,7 @@ class NodeFolderAbstractAction(AbstractAction, ABC): ] -class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"): +class NodeFolderScanAction(NodeFolderAbstractAction, discriminator="node_folder_scan"): """Action which scans a folder.""" config: "NodeFolderScanAction.ConfigSchema" @@ -58,7 +58,7 @@ class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_sca verb: ClassVar[str] = "scan" -class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"): +class NodeFolderCheckhashAction(NodeFolderAbstractAction, discriminator="node_folder_checkhash"): """Action which checks the hash of a folder.""" config: "NodeFolderCheckhashAction.ConfigSchema" @@ -69,7 +69,7 @@ class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folde verb: ClassVar[str] = "checkhash" -class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"): +class NodeFolderRepairAction(NodeFolderAbstractAction, discriminator="node_folder_repair"): """Action which repairs a folder.""" config: "NodeFolderRepairAction.ConfigSchema" @@ -80,7 +80,7 @@ class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_r verb: ClassVar[str] = "repair" -class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"): +class NodeFolderRestoreAction(NodeFolderAbstractAction, discriminator="node_folder_restore"): """Action which restores a folder.""" config: "NodeFolderRestoreAction.ConfigSchema" @@ -91,7 +91,7 @@ class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_ verb: ClassVar[str] = "restore" -class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"): +class NodeFolderCreateAction(NodeFolderAbstractAction, discriminator="node_folder_create"): """Action which creates a new folder.""" config: "NodeFolderCreateAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/host_nic.py b/src/primaite/game/agent/actions/host_nic.py index 0ca816f3..35599325 100644 --- a/src/primaite/game/agent/actions/host_nic.py +++ b/src/primaite/game/agent/actions/host_nic.py @@ -40,7 +40,7 @@ class HostNICAbstractAction(AbstractAction, ABC): ] -class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"): +class HostNICEnableAction(HostNICAbstractAction, discriminator="host_nic_enable"): """Action which enables a NIC.""" config: "HostNICEnableAction.ConfigSchema" @@ -51,7 +51,7 @@ class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"): verb: ClassVar[str] = "enable" -class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"): +class HostNICDisableAction(HostNICAbstractAction, discriminator="host_nic_disable"): """Action which disables a NIC.""" config: "HostNICDisableAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index 3e5b21b1..8332368e 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -24,7 +24,7 @@ from primaite.interface.request import RequestFormat __all__ = ("DoNothingAction", "ActionManager") -class DoNothingAction(AbstractAction, identifier="do_nothing"): +class DoNothingAction(AbstractAction, discriminator="do_nothing"): """Do Nothing Action.""" class ConfigSchema(AbstractAction.ConfigSchema): diff --git a/src/primaite/game/agent/actions/network.py b/src/primaite/game/agent/actions/network.py index d244fb74..afa22861 100644 --- a/src/primaite/game/agent/actions/network.py +++ b/src/primaite/game/agent/actions/network.py @@ -8,7 +8,7 @@ from primaite.interface.request import RequestFormat __all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction") -class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"): +class NetworkPortAbstractAction(AbstractAction, discriminator="network_port_abstract"): """Base class for Network port actions.""" config: "NetworkPortAbstractAction.ConfigSchema" @@ -35,7 +35,7 @@ class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstrac ] -class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"): +class NetworkPortEnableAction(NetworkPortAbstractAction, discriminator="network_port_enable"): """Action which enables are port on a router or a firewall.""" config: "NetworkPortEnableAction.ConfigSchema" @@ -46,7 +46,7 @@ class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_por verb: ClassVar[str] = "enable" -class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"): +class NetworkPortDisableAction(NetworkPortAbstractAction, discriminator="network_port_disable"): """Action which disables are port on a router or a firewall.""" config: "NetworkPortDisableAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index 5e1b6725..7f7b01a2 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -18,7 +18,7 @@ __all__ = ( ) -class NodeAbstractAction(AbstractAction, identifier="node_abstract"): +class NodeAbstractAction(AbstractAction, discriminator="node_abstract"): """ Abstract base class for node actions. @@ -39,7 +39,7 @@ class NodeAbstractAction(AbstractAction, identifier="node_abstract"): return ["network", "node", config.node_name, config.verb] -class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"): +class NodeOSScanAction(NodeAbstractAction, discriminator="node_os_scan"): """Action which scans a node's OS.""" config: "NodeOSScanAction.ConfigSchema" @@ -50,7 +50,7 @@ class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"): verb: ClassVar[str] = "scan" -class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"): +class NodeShutdownAction(NodeAbstractAction, discriminator="node_shutdown"): """Action which shuts down a node.""" config: "NodeShutdownAction.ConfigSchema" @@ -61,7 +61,7 @@ class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"): verb: ClassVar[str] = "shutdown" -class NodeStartupAction(NodeAbstractAction, identifier="node_startup"): +class NodeStartupAction(NodeAbstractAction, discriminator="node_startup"): """Action which starts up a node.""" config: "NodeStartupAction.ConfigSchema" @@ -72,7 +72,7 @@ class NodeStartupAction(NodeAbstractAction, identifier="node_startup"): verb: ClassVar[str] = "startup" -class NodeResetAction(NodeAbstractAction, identifier="node_reset"): +class NodeResetAction(NodeAbstractAction, discriminator="node_reset"): """Action which resets a node.""" config: "NodeResetAction.ConfigSchema" @@ -83,7 +83,7 @@ class NodeResetAction(NodeAbstractAction, identifier="node_reset"): verb: ClassVar[str] = "reset" -class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"): +class NodeNMAPAbstractAction(AbstractAction, discriminator="node_nmap_abstract_action"): """Base class for NodeNMAP actions.""" config: "NodeNMAPAbstractAction.ConfigSchema" @@ -103,7 +103,7 @@ class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_acti pass -class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"): +class NodeNMAPPingScanAction(NodeNMAPAbstractAction, discriminator="node_nmap_ping_scan"): """Action which performs an NMAP ping scan.""" config: "NodeNMAPPingScanAction.ConfigSchema" @@ -122,7 +122,7 @@ class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_ ] -class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"): +class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node_nmap_port_scan"): """Action which performs an NMAP port scan.""" config: "NodeNMAPPortScanAction.ConfigSchema" @@ -154,7 +154,7 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_ ] -class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"): +class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node_network_service_recon"): """Action which performs an NMAP network service recon (ping scan followed by port scan).""" config: "NodeNetworkServiceReconAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/service.py b/src/primaite/game/agent/actions/service.py index 4a483f28..4adbe139 100644 --- a/src/primaite/game/agent/actions/service.py +++ b/src/primaite/game/agent/actions/service.py @@ -17,7 +17,7 @@ __all__ = ( ) -class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"): +class NodeServiceAbstractAction(AbstractAction, discriminator="node_service_abstract"): """Abstract Action for Node Service related actions. Any actions which use node_name and service_name can inherit from this class. @@ -36,7 +36,7 @@ class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstrac return ["network", "node", config.node_name, "service", config.service_name, config.verb] -class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"): +class NodeServiceScanAction(NodeServiceAbstractAction, discriminator="node_service_scan"): """Action which scans a service.""" config: "NodeServiceScanAction.ConfigSchema" @@ -47,7 +47,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_ verb: ClassVar[str] = "scan" -class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"): +class NodeServiceStopAction(NodeServiceAbstractAction, discriminator="node_service_stop"): """Action which stops a service.""" config: "NodeServiceStopAction.ConfigSchema" @@ -58,7 +58,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_ verb: ClassVar[str] = "stop" -class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"): +class NodeServiceStartAction(NodeServiceAbstractAction, discriminator="node_service_start"): """Action which starts a service.""" config: "NodeServiceStartAction.ConfigSchema" @@ -69,7 +69,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service verb: ClassVar[str] = "start" -class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"): +class NodeServicePauseAction(NodeServiceAbstractAction, discriminator="node_service_pause"): """Action which pauses a service.""" config: "NodeServicePauseAction.ConfigSchema" @@ -80,7 +80,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service verb: ClassVar[str] = "pause" -class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"): +class NodeServiceResumeAction(NodeServiceAbstractAction, discriminator="node_service_resume"): """Action which resumes a service.""" config: "NodeServiceResumeAction.ConfigSchema" @@ -91,7 +91,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_servic verb: ClassVar[str] = "resume" -class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"): +class NodeServiceRestartAction(NodeServiceAbstractAction, discriminator="node_service_restart"): """Action which restarts a service.""" config: "NodeServiceRestartAction.ConfigSchema" @@ -102,7 +102,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_servi verb: ClassVar[str] = "restart" -class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"): +class NodeServiceDisableAction(NodeServiceAbstractAction, discriminator="node_service_disable"): """Action which disables a service.""" config: "NodeServiceDisableAction.ConfigSchema" @@ -113,7 +113,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_servi verb: ClassVar[str] = "disable" -class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"): +class NodeServiceEnableAction(NodeServiceAbstractAction, discriminator="node_service_enable"): """Action which enables a service.""" config: "NodeServiceEnableAction.ConfigSchema" @@ -124,7 +124,7 @@ class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_servic verb: ClassVar[str] = "enable" -class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"): +class NodeServiceFixAction(NodeServiceAbstractAction, discriminator="node_service_fix"): """Action which fixes a service.""" config: "NodeServiceFixAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index 9720d371..4bed1943 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -11,7 +11,7 @@ __all__ = ( ) -class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"): +class NodeSessionAbstractAction(AbstractAction, discriminator="node_session_abstract"): """Base class for NodeSession actions.""" config: "NodeSessionAbstractAction.ConfigSchema" @@ -33,7 +33,7 @@ class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstrac pass -class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"): +class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node_session_remote_login"): """Action which performs a remote session login.""" config: "NodeSessionsRemoteLoginAction.ConfigSchema" @@ -62,7 +62,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_ ] -class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"): +class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="node_session_remote_logoff"): """Action which performs a remote session logout.""" config: "NodeSessionsRemoteLogoutAction.ConfigSchema" @@ -80,7 +80,7 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip] -class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"): +class NodeAccountChangePasswordAction(NodeSessionAbstractAction, discriminator="node_account_change_password"): """Action which changes the password for a user.""" config: "NodeAccountChangePasswordAction.ConfigSchema" diff --git a/src/primaite/game/agent/actions/software.py b/src/primaite/game/agent/actions/software.py index 81a3a315..49edd7c5 100644 --- a/src/primaite/game/agent/actions/software.py +++ b/src/primaite/game/agent/actions/software.py @@ -22,7 +22,7 @@ __all__ = ( ) -class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_ransomware_script"): +class ConfigureRansomwareScriptAction(AbstractAction, discriminator="configure_ransomware_script"): """Action which sets config parameters for a ransomware script on a node.""" config: "ConfigureRansomwareScriptAction.ConfigSchema" @@ -48,7 +48,9 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_rans return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", data] -class RansomwareConfigureC2ServerAction(ConfigureRansomwareScriptAction, identifier="c2_server_ransomware_configure"): +class RansomwareConfigureC2ServerAction( + ConfigureRansomwareScriptAction, discriminator="c2_server_ransomware_configure" +): """Action which causes a C2 server to send a command to set options on a ransomware script remotely.""" @classmethod @@ -59,7 +61,7 @@ class RansomwareConfigureC2ServerAction(ConfigureRansomwareScriptAction, identif return ["network", "node", config.node_name, "application", "C2Server", "ransomware_configure", data] -class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): +class ConfigureDoSBotAction(AbstractAction, discriminator="configure_dos_bot"): """Action which sets config parameters for a DoS bot on a node.""" class ConfigSchema(AbstractAction.ConfigSchema): @@ -91,7 +93,7 @@ class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): return ["network", "node", config.node_name, "application", "DoSBot", "configure", data] -class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): +class ConfigureC2BeaconAction(AbstractAction, discriminator="configure_c2_beacon"): """Action which configures a C2 Beacon based on the parameters given.""" class ConfigSchema(AbstractAction.ConfigSchema): @@ -115,7 +117,7 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): return ["network", "node", config.node_name, "application", "C2Beacon", "configure", data] -class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"): +class NodeSendRemoteCommandAction(AbstractAction, discriminator="node_send_remote_command"): """Action which sends a terminal command to a remote node via SSH.""" config: "NodeSendRemoteCommandAction.ConfigSchema" @@ -142,7 +144,7 @@ class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_c ] -class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"): +class TerminalC2ServerAction(AbstractAction, discriminator="c2_server_terminal_command"): """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" config: "TerminalC2ServerAction.ConfigSchema" @@ -171,7 +173,7 @@ class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_comm return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model] -class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"): +class RansomwareLaunchC2ServerAction(AbstractAction, discriminator="c2_server_ransomware_launch"): """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" config: "RansomwareLaunchC2ServerAction.ConfigSchema" @@ -190,7 +192,7 @@ class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ranso return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"] -class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"): +class ExfiltrationC2ServerAction(AbstractAction, discriminator="c2_server_data_exfiltrate"): """Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server.""" config: "ExfiltrationC2ServerAction.ConfigSchema" @@ -223,7 +225,7 @@ class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfi return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model] -class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"): +class ConfigureDatabaseClientAction(AbstractAction, discriminator="configure_database_client"): """Action which sets config parameters for a database client on a node.""" config: "ConfigureDatabaseClientAction.ConfigSchema" diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index aac898e1..cb1c15dd 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -79,13 +79,13 @@ class AbstractAgent(BaseModel, ABC): _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - if identifier in cls._registry: - raise ValueError(f"Cannot create a new agent under reserved name {identifier}") - cls._registry[identifier] = cls + if discriminator in cls._registry: + raise ValueError(f"Cannot create a new agent under reserved name {discriminator}") + cls._registry[discriminator] = cls def model_post_init(self, __context: Any) -> None: """Overwrite the default empty action, observation, and rewards with ones defined through the config.""" @@ -161,7 +161,7 @@ class AbstractAgent(BaseModel, ABC): return agent_class(config=config) -class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"): +class AbstractScriptedAgent(AbstractAgent, discriminator="AbstractScriptedAgent"): """Base class for actors which generate their own behaviour.""" config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema()) @@ -177,7 +177,7 @@ class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"): return super().get_action(obs=obs, timestep=timestep) -class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): +class ProxyAgent(AbstractAgent, discriminator="ProxyAgent"): """Agent that sends observations to an RL model and receives actions from that model.""" config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema()) diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index fde49a6b..ef171431 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -16,7 +16,7 @@ from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class ACLObservation(AbstractObservation, identifier="ACL"): +class ACLObservation(AbstractObservation, discriminator="ACL"): """ACL observation, provides information about access control lists within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 784eaa7f..82ae9acc 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -13,7 +13,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) -class FileObservation(AbstractObservation, identifier="FILE"): +class FileObservation(AbstractObservation, discriminator="FILE"): """File observation, provides status information about a file within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -125,7 +125,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): ) -class FolderObservation(AbstractObservation, identifier="FOLDER"): +class FolderObservation(AbstractObservation, discriminator="FOLDER"): """Folder observation, provides status information about a folder within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 6e5fffb9..f0390697 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -18,7 +18,7 @@ from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class FirewallObservation(AbstractObservation, identifier="FIREWALL"): +class FirewallObservation(AbstractObservation, discriminator="FIREWALL"): """Firewall observation, provides status information about a firewall within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index e46cc805..ed4dd21c 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -18,7 +18,7 @@ from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class HostObservation(AbstractObservation, identifier="HOST"): +class HostObservation(AbstractObservation, discriminator="HOST"): """Host observation, provides status information about a host within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/link_observation.py b/src/primaite/game/agent/observations/link_observation.py index 851e9557..303e421c 100644 --- a/src/primaite/game/agent/observations/link_observation.py +++ b/src/primaite/game/agent/observations/link_observation.py @@ -13,7 +13,7 @@ from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_ST _LOGGER = getLogger(__name__) -class LinkObservation(AbstractObservation, identifier="LINK"): +class LinkObservation(AbstractObservation, discriminator="LINK"): """Link observation, providing information about a specific link within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -90,7 +90,7 @@ class LinkObservation(AbstractObservation, identifier="LINK"): return cls(where=where) -class LinksObservation(AbstractObservation, identifier="LINKS"): +class LinksObservation(AbstractObservation, discriminator="LINKS"): """Collection of link observations representing multiple links within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index f87d2d76..4c8fbaf5 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -12,7 +12,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.port import Port -class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): +class NICObservation(AbstractObservation, discriminator="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -227,7 +227,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): ) -class PortObservation(AbstractObservation, identifier="PORT"): +class PortObservation(AbstractObservation, discriminator="PORT"): """Port observation, provides status information about a network port within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 1a0f48b4..2937aa7c 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -19,7 +19,7 @@ from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class NodesObservation(AbstractObservation, identifier="NODES"): +class NodesObservation(AbstractObservation, discriminator="NODES"): """Nodes observation, provides status information about nodes within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 83d4a076..0d28aa98 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, computed_field, ConfigDict, Field, model_validat from primaite.game.agent.observations.observations import AbstractObservation, WhereType -class NestedObservation(AbstractObservation, identifier="CUSTOM"): +class NestedObservation(AbstractObservation, discriminator="CUSTOM"): """Observation type that allows combining other observations into a gymnasium.spaces.Dict space.""" class NestedObservationItem(BaseModel): @@ -19,7 +19,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): model_config = ConfigDict(extra="forbid") type: str - """Select observation class. It maps to the identifier of the obs class by checking the registry.""" + """Select observation class. It maps to the discriminator of the obs class by checking the registry.""" label: str """Dict key in the final observation space.""" options: Dict @@ -119,7 +119,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): return cls(components=instances) -class NullObservation(AbstractObservation, identifier="NONE"): +class NullObservation(AbstractObservation, discriminator="NONE"): """Empty observation that acts as a placeholder.""" def __init__(self) -> None: @@ -158,7 +158,7 @@ class ObservationManager(BaseModel): model_config = ConfigDict(extra="forbid") type: str = "NONE" - """Identifier name for the top-level observation.""" + """discriminator name for the top-level observation.""" options: AbstractObservation.ConfigSchema = Field( default_factory=lambda: NullObservation.ConfigSchema(), validate_default=True ) @@ -235,7 +235,7 @@ class ObservationManager(BaseModel): :param config: Dictionary containing the configuration for this observation space. If None, a blank observation space is created. Otherwise, this must be a Dict with a type field and options field. - type: string that corresponds to one of the observation identifiers that are provided when subclassing + type: string that corresponds to one of the observation discriminators that are provided when subclassing AbstractObservation options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 89c45b37..da81d2ad 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -31,20 +31,20 @@ class AbstractObservation(ABC): """Initialise an observation. This method must be overwritten.""" self.default_observation: ObsType - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register an observation type. - :param identifier: Identifier used to uniquely specify observation component types. - :type identifier: str + :param discriminator: discriminator used to uniquely specify observation component types. + :type discriminator: str :raises ValueError: When attempting to create a component with a name that is already in use. """ super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - if identifier in cls._registry: - raise ValueError(f"Duplicate observation component type {identifier}") - cls._registry[identifier] = cls + if discriminator in cls._registry: + raise ValueError(f"Duplicate observation component type {discriminator}") + cls._registry[discriminator] = cls @abstractmethod def observe(self, state: Dict) -> Any: diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index ab759779..8eaad1b1 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -18,7 +18,7 @@ from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) -class RouterObservation(AbstractObservation, identifier="ROUTER"): +class RouterObservation(AbstractObservation, discriminator="ROUTER"): """Router observation, provides status information about a router within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 37810c6e..6e2fbb73 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -10,7 +10,7 @@ from primaite.game.agent.observations.observations import AbstractObservation, W from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -class ServiceObservation(AbstractObservation, identifier="SERVICE"): +class ServiceObservation(AbstractObservation, discriminator="SERVICE"): """Service observation, shows status of a service in the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): @@ -73,7 +73,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): return cls(where=parent_where + ["services", config.service_name]) -class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): +class ApplicationObservation(AbstractObservation, discriminator="APPLICATION"): """Application observation, shows the status of an application within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 80be14ef..3e961bdf 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -55,13 +55,13 @@ class AbstractReward(BaseModel): _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - if identifier in cls._registry: - raise ValueError(f"Duplicate reward {identifier}") - cls._registry[identifier] = cls + if discriminator in cls._registry: + raise ValueError(f"Duplicate reward {discriminator}") + cls._registry[discriminator] = cls @classmethod def from_config(cls, config: Dict) -> "AbstractReward": @@ -92,7 +92,7 @@ class AbstractReward(BaseModel): return 0.0 -class DummyReward(AbstractReward, identifier="DUMMY"): +class DummyReward(AbstractReward, discriminator="DUMMY"): """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -108,7 +108,7 @@ class DummyReward(AbstractReward, identifier="DUMMY"): return 0.0 -class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): +class DatabaseFileIntegrity(AbstractReward, discriminator="DATABASE_FILE_INTEGRITY"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" config: "DatabaseFileIntegrity.ConfigSchema" @@ -161,7 +161,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY" return 0 -class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): +class WebServer404Penalty(AbstractReward, discriminator="WEB_SERVER_404_PENALTY"): """Reward function component which penalises the agent when the web server returns a 404 error.""" config: "WebServer404Penalty.ConfigSchema" @@ -215,7 +215,7 @@ class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): return self.reward -class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"): +class WebpageUnavailablePenalty(AbstractReward, discriminator="WEBPAGE_UNAVAILABLE_PENALTY"): """Penalises the agent when the web browser fails to fetch a webpage.""" config: "WebpageUnavailablePenalty.ConfigSchema" @@ -289,7 +289,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_ return self.reward -class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, discriminator="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"): """Penalises the agent when the green db clients fail to connect to the database.""" config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" @@ -339,7 +339,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADM return self.reward -class SharedReward(AbstractReward, identifier="SHARED_REWARD"): +class SharedReward(AbstractReward, discriminator="SHARED_REWARD"): """Adds another agent's reward to the overall reward.""" config: "SharedReward.ConfigSchema" @@ -376,7 +376,7 @@ class SharedReward(AbstractReward, identifier="SHARED_REWARD"): return self.callback(self.config.agent_name) -class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"): +class ActionPenalty(AbstractReward, discriminator="ACTION_PENALTY"): """Apply a negative reward when taking any action except do_nothing.""" config: "ActionPenalty.ConfigSchema" diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py index e6ddd546..f36c93de 100644 --- a/src/primaite/game/agent/scripted_agents/abstract_tap.py +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -13,7 +13,7 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent __all__ = "AbstractTAPAgent" -class AbstractTAPAgent(PeriodicAgent, identifier="AbstractTAP"): +class AbstractTAPAgent(PeriodicAgent, discriminator="AbstractTAP"): """Base class for TAP agents to inherit from.""" config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema()) diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index a7558d42..b32df428 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -9,7 +9,7 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent __all__ = "DataManipulationAgent" -class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgent"): +class DataManipulationAgent(PeriodicAgent, discriminator="RedDatabaseCorruptingAgent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema): diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index de643ed8..2ddc39b7 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -13,7 +13,7 @@ from primaite.game.agent.interface import AbstractScriptedAgent __all__ = "ProbabilisticAgent" -class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"): +class ProbabilisticAgent(AbstractScriptedAgent, discriminator="ProbabilisticAgent"): """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" rng: Generator = Field(default_factory=lambda: np.random.default_rng(np.random.randint(0, 65535))) diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index 9cf8e798..3d652dfc 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -11,7 +11,7 @@ from primaite.game.agent.interface import AbstractScriptedAgent __all__ = ("RandomAgent", "PeriodicAgent") -class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"): +class RandomAgent(AbstractScriptedAgent, discriminator="RandomAgent"): """Agent that ignores its observation and acts completely at random.""" config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema()) @@ -34,7 +34,7 @@ class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"): return self.action_manager.get_action(self.action_manager.space.sample()) -class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): +class PeriodicAgent(AbstractScriptedAgent, discriminator="PeriodicAgent"): """Agent that does nothing most of the time, but executes application at regular intervals (with variance).""" config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema()) diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 7af8b98e..690e7856 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -172,7 +172,7 @@ "\n", "\n", "# no applications exist yet so we will create our own.\n", - "class MSPaint(Application, identifier=\"MSPaint\"):\n", + "class MSPaint(Application, discriminator=\"MSPaint\"):\n", " def describe_state(self):\n", " return super().describe_state()" ] diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 567a0493..7ccd202e 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -244,7 +244,7 @@ class SimComponent(BaseModel): ..code::python - class WebBrowser(Application, identifier="WebBrowser"): + class WebBrowser(Application, discriminator="WebBrowser"): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() # all requests generic to any Application get initialised rm.add_request(...) # initialise any requests specific to the web browser diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index bf677d5c..f5ae0232 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -180,7 +180,7 @@ class Network(SimComponent): table.align = "l" table.title = "Nodes" for node in self.nodes.values(): - table.add_row((node.hostname, type(node)._identifier, node.operating_state.name)) + table.add_row((node.hostname, type(node)._discriminator, node.operating_state.name)) print(table) if ip_addresses: diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 2cf8774e..e16a7fcc 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -22,7 +22,7 @@ class NetworkNodeAdder(BaseModel): Here is a template that users can use to define custom node adders: ``` - class YourNodeAdder(NetworkNodeAdder, identifier="your_name"): + class YourNodeAdder(NetworkNodeAdder, discriminator="your_name"): class ConfigSchema(NetworkNodeAdder.ConfigSchema): property_1 : str property_2 : int @@ -40,8 +40,8 @@ class NetworkNodeAdder(BaseModel): """ Base schema for node adders. - Child classes of NetworkNodeAdder must define a schema which inherits from this schema. The identifier is used - by the from_config method to select the correct node adder at runtime. + Child classes of NetworkNodeAdder must define a schema which inherits from this schema. The discriminator is + used by the from_config method to select the correct node adder at runtime. """ model_config = ConfigDict(extra="forbid") @@ -50,20 +50,20 @@ class NetworkNodeAdder(BaseModel): _registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {} - def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str], **kwargs: Any) -> None: """ Register a network node adder class. - :param identifier: Unique name for the node adder to use for matching against primaite config entries. - :type identifier: str + :param discriminator: Unique name for the node adder to use for matching against primaite config entries. + :type discriminator: str :raises ValueError: When attempting to register a name that is already reserved. """ super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - if identifier in cls._registry: - raise ValueError(f"Duplicate node adder {identifier}") - cls._registry[identifier] = cls + if discriminator in cls._registry: + raise ValueError(f"Duplicate node adder {discriminator}") + cls._registry[discriminator] = cls @classmethod @abstractmethod @@ -99,7 +99,7 @@ class NetworkNodeAdder(BaseModel): adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config), network=network) -class OfficeLANAdder(NetworkNodeAdder, identifier="office_lan"): +class OfficeLANAdder(NetworkNodeAdder, discriminator="office_lan"): """Creates an office LAN.""" class ConfigSchema(NetworkNodeAdder.ConfigSchema): diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index ecbd0629..bacba15b 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -824,7 +824,7 @@ class User(SimComponent): return self.model_dump() -class UserManager(Service, identifier="UserManager"): +class UserManager(Service, discriminator="UserManager"): """ Manages users within the PrimAITE system, handling creation, authentication, and administration. @@ -1137,7 +1137,7 @@ class RemoteUserSession(UserSession): return state -class UserSessionManager(Service, identifier="UserSessionManager"): +class UserSessionManager(Service, discriminator="UserSessionManager"): """ Manages user sessions on a Node, including local and remote sessions. @@ -1483,7 +1483,7 @@ class UserSessionManager(Service, identifier="UserSessionManager"): return self.local_session is not None -class Node(SimComponent): +class Node(SimComponent, ABC): """ A basic Node class that represents a node on the network. @@ -1556,25 +1556,26 @@ class Node(SimComponent): _registry: ClassVar[Dict[str, Type["Node"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - _identifier: ClassVar[str] = "unknown" - """Identifier for this particular class, used for printing and logging. Each subclass redefines this.""" + # TODO: this should not be set for abstract classes. + _discriminator: ClassVar[str] + """discriminator for this particular class, used for printing and logging. Each subclass redefines this.""" - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register a node type. - :param identifier: Uniquely specifies an node class by name. Used for finding items by config. - :type identifier: str + :param discriminator: Uniquely specifies an node class by name. Used for finding items by config. + :type discriminator: str :raises ValueError: When attempting to register an node with a name that is already allocated. """ super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - identifier = identifier.lower() - if identifier in cls._registry: - raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.") - cls._registry[identifier] = cls - cls._identifier = identifier + discriminator = discriminator.lower() + if discriminator in cls._registry: + raise ValueError(f"Tried to define new node {discriminator}, but this name is already reserved.") + cls._registry[discriminator] = cls + cls._discriminator = discriminator def __init__(self, **kwargs): """ diff --git a/src/primaite/simulator/network/hardware/nodes/host/computer.py b/src/primaite/simulator/network/hardware/nodes/host/computer.py index 11b925b9..a47af2ad 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/computer.py +++ b/src/primaite/simulator/network/hardware/nodes/host/computer.py @@ -5,7 +5,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.ftp.ftp_client import FTPClient -class Computer(HostNode, identifier="computer"): +class Computer(HostNode, discriminator="computer"): """ A basic Computer class. diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index c51afbca..f8786a08 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -262,7 +262,7 @@ class NIC(IPWiredNetworkInterface): return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" -class HostNode(Node, identifier="HostNode"): +class HostNode(Node, discriminator="HostNode"): """ Represents a host node in the network. diff --git a/src/primaite/simulator/network/hardware/nodes/host/server.py b/src/primaite/simulator/network/hardware/nodes/host/server.py index e16cfd8f..50b82122 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/server.py +++ b/src/primaite/simulator/network/hardware/nodes/host/server.py @@ -2,7 +2,7 @@ from primaite.simulator.network.hardware.nodes.host.host_node import HostNode -class Server(HostNode, identifier="server"): +class Server(HostNode, discriminator="server"): """ A basic Server class. @@ -31,7 +31,7 @@ class Server(HostNode, identifier="server"): """ -class Printer(HostNode, identifier="printer"): +class Printer(HostNode, discriminator="printer"): """Printer? I don't even know her!.""" # TODO: Implement printer-specific behaviour diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index ac7c12e3..4da9e24c 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -27,7 +27,7 @@ DMZ_PORT_ID: Final[int] = 3 """The Firewall port ID of the DMZ port.""" -class Firewall(Router, identifier="firewall"): +class Firewall(Router, discriminator="firewall"): """ A Firewall class that extends the functionality of a Router. diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index 22ff2b28..185b6bae 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -7,7 +7,7 @@ from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.services.arp.arp import ARP -class NetworkNode(Node, identifier="NetworkNode"): +class NetworkNode(Node, discriminator="NetworkNode"): """ Represents an abstract base class for a network node that can receive and process network frames. diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 4a049f99..b6004e8e 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1184,7 +1184,7 @@ class RouterSessionManager(SessionManager): return outbound_network_interface, dst_mac_address, dst_ip_address, src_port, dst_port, protocol, is_broadcast -class Router(NetworkNode, identifier="router"): +class Router(NetworkNode, discriminator="router"): """ Represents a network router, managing routing and forwarding of IP packets across network interfaces. diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index db923f1a..f06337aa 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -87,7 +87,7 @@ class SwitchPort(WiredNetworkInterface): return False -class Switch(NetworkNode, identifier="switch"): +class Switch(NetworkNode, discriminator="switch"): """ A class representing a Layer 2 network switch. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 804a570e..87408670 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -91,7 +91,7 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): ) -class WirelessRouter(Router, identifier="wireless_router"): +class WirelessRouter(Router, discriminator="wireless_router"): """ A WirelessRouter class that extends the functionality of a standard Router to include wireless capabilities. diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 05a47d7a..1de29c33 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -53,20 +53,20 @@ class Application(IOSoftware, ABC): _registry: ClassVar[Dict[str, Type["Application"]]] = {} """Registry of application types. Automatically populated when subclasses are defined.""" - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register an application type. - :param identifier: Uniquely specifies an application class by name. Used for finding items by config. - :type identifier: Optional[str] + :param discriminator: Uniquely specifies an application class by name. Used for finding items by config. + :type discriminator: Optional[str] :raises ValueError: When attempting to register an application with a name that is already allocated. """ super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return - if identifier in cls._registry: - raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.") - cls._registry[identifier] = cls + if discriminator in cls._registry: + raise ValueError(f"Tried to define new application {discriminator}, but this name is already reserved.") + cls._registry[discriminator] = cls @classmethod def from_config(cls, config: Dict) -> "Application": diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 96130e16..67749e21 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -61,7 +61,7 @@ class DatabaseClientConnection(BaseModel): return str(self) -class DatabaseClient(Application, identifier="DatabaseClient"): +class DatabaseClient(Application, discriminator="DatabaseClient"): """ A DatabaseClient application. diff --git a/src/primaite/simulator/system/applications/nmap.py b/src/primaite/simulator/system/applications/nmap.py index 3eeda4b6..6a29aedf 100644 --- a/src/primaite/simulator/system/applications/nmap.py +++ b/src/primaite/simulator/system/applications/nmap.py @@ -44,7 +44,7 @@ class PortScanPayload(SimComponent): return state -class NMAP(Application, identifier="NMAP"): +class NMAP(Application, discriminator="NMAP"): """ A class representing the NMAP application for network scanning. diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py index b989671e..14e446a4 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_beacon.py @@ -17,7 +17,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import Port, PORT_LOOKUP -class C2Beacon(AbstractC2, identifier="C2Beacon"): +class C2Beacon(AbstractC2, discriminator="C2Beacon"): """ C2 Beacon Application. diff --git a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py index 9d2097e9..df4c34a8 100644 --- a/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py +++ b/src/primaite/simulator/system/applications/red_applications/c2/c2_server.py @@ -16,7 +16,7 @@ from primaite.simulator.system.applications.red_applications.c2 import ( from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload -class C2Server(AbstractC2, identifier="C2Server"): +class C2Server(AbstractC2, discriminator="C2Server"): """ C2 Server Application. diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 392cdfba..7ad31e3b 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -40,7 +40,7 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(Application, identifier="DataManipulationBot"): +class DataManipulationBot(Application, discriminator="DataManipulationBot"): """A bot that simulates a script which performs a SQL injection attack.""" class ConfigSchema(Application.ConfigSchema): diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index a6cb2b75..6153c2a5 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -32,7 +32,7 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient, identifier="DoSBot"): +class DoSBot(DatabaseClient, discriminator="DoSBot"): """A bot that simulates a Denial of Service attack.""" class ConfigSchema(DatabaseClient.ConfigSchema): diff --git a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py index 114d5716..0a818a85 100644 --- a/src/primaite/simulator/system/applications/red_applications/ransomware_script.py +++ b/src/primaite/simulator/system/applications/red_applications/ransomware_script.py @@ -14,7 +14,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP -class RansomwareScript(Application, identifier="RansomwareScript"): +class RansomwareScript(Application, discriminator="RansomwareScript"): """Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network. :ivar payload: The attack stage query payload. (Default ENCRYPT) diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 49f303b5..3eb18f7f 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -23,7 +23,7 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class WebBrowser(Application, identifier="WebBrowser"): +class WebBrowser(Application, discriminator="WebBrowser"): """ Represents a web browser in the simulation environment. diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index bbeec301..311f7e25 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -15,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address from primaite.utils.validation.port import PORT_LOOKUP -class ARP(Service, identifier="ARP"): +class ARP(Service, discriminator="ARP"): """ The ARP (Address Resolution Protocol) Service. diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 1745b9d1..369905db 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -19,7 +19,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class DatabaseService(Service, identifier="DatabaseService"): +class DatabaseService(Service, discriminator="DatabaseService"): """ A class for simulating a generic SQL Server service. diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 825896e0..83a14033 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: _LOGGER = getLogger(__name__) -class DNSClient(Service, identifier="DNSClient"): +class DNSClient(Service, discriminator="DNSClient"): """Represents a DNS Client as a Service.""" class ConfigSchema(Service.ConfigSchema): diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 41a5b25f..ef19a13e 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -14,7 +14,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class DNSServer(Service, identifier="DNSServer"): +class DNSServer(Service, discriminator="DNSServer"): """Represents a DNS Server as a Service.""" class ConfigSchema(Service.ConfigSchema): diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 82875b97..23b55330 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -18,7 +18,7 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class FTPClient(FTPServiceABC, identifier="FTPClient"): +class FTPClient(FTPServiceABC, discriminator="FTPClient"): """ A class for simulating an FTP client service. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 5f4ac846..43184684 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -12,7 +12,7 @@ from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class FTPServer(FTPServiceABC, identifier="FTPServer"): +class FTPServer(FTPServiceABC, discriminator="FTPServer"): """ A class for simulating an FTP server service. diff --git a/src/primaite/simulator/system/services/icmp/icmp.py b/src/primaite/simulator/system/services/icmp/icmp.py index 7f626945..77dbd5be 100644 --- a/src/primaite/simulator/system/services/icmp/icmp.py +++ b/src/primaite/simulator/system/services/icmp/icmp.py @@ -16,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ICMP(Service, identifier="ICMP"): +class ICMP(Service, discriminator="ICMP"): """ The Internet Control Message Protocol (ICMP) service. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index b5f921c9..e3af43f7 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -15,7 +15,7 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class NTPClient(Service, identifier="NTPClient"): +class NTPClient(Service, discriminator="NTPClient"): """Represents a NTP client as a service.""" class ConfigSchema(Service.ConfigSchema): diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 7af33893..b2d8356c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -13,7 +13,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class NTPServer(Service, identifier="NTPServer"): +class NTPServer(Service, discriminator="NTPServer"): """Represents a NTP server as a service.""" class ConfigSchema(Service.ConfigSchema): diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index c4a73301..a7b8fd09 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -61,22 +61,22 @@ class Service(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None: + def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None: """ Register a hostnode type. - :param identifier: Uniquely specifies an hostnode class by name. Used for finding items by config. - :type identifier: str + :param discriminator: Uniquely specifies an hostnode class by name. Used for finding items by config. + :type discriminator: str :raises ValueError: When attempting to register an hostnode with a name that is already allocated. """ super().__init_subclass__(**kwargs) - if identifier is None: + if discriminator is None: return # Enforce lowercase registry entries because it makes comparisons everywhere else much easier. - identifier = identifier.lower() - if identifier in cls._registry: - raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.") - cls._registry[identifier] = cls + discriminator = discriminator.lower() + if discriminator in cls._registry: + raise ValueError(f"Tried to define new hostnode {discriminator}, but this name is already reserved.") + cls._registry[discriminator] = cls @classmethod def from_config(cls, config: Dict) -> "Service": diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index bda8bad3..01d9095b 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -129,7 +129,7 @@ class RemoteTerminalConnection(TerminalClientConnection): return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id) -class Terminal(Service, identifier="Terminal"): +class Terminal(Service, discriminator="Terminal"): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" class ConfigSchema(Service.ConfigSchema): diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 51724002..40a713a5 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -21,7 +21,7 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) -class WebServer(Service, identifier="WebServer"): +class WebServer(Service, discriminator="WebServer"): """Class used to represent a Web Server Service in simulation.""" class ConfigSchema(Service.ConfigSchema): diff --git a/tests/conftest.py b/tests/conftest.py index 165ab30e..70443042 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,7 +39,7 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) -class DummyService(Service, identifier="DummyService"): +class DummyService(Service, discriminator="DummyService"): """Test Service class""" class ConfigSchema(Service.ConfigSchema): @@ -62,7 +62,7 @@ class DummyService(Service, identifier="DummyService"): pass -class DummyApplication(Application, identifier="DummyApplication"): +class DummyApplication(Application, discriminator="DummyApplication"): """Test Application class""" class ConfigSchema(Application.ConfigSchema): @@ -280,7 +280,7 @@ def example_network() -> Network: return network -class ControlledAgent(AbstractAgent, identifier="ControlledAgent"): +class ControlledAgent(AbstractAgent, discriminator="ControlledAgent"): """Agent that can be controlled by the tests.""" config: "ControlledAgent.ConfigSchema" = Field(default_factory=lambda: ControlledAgent.ConfigSchema()) diff --git a/tests/integration_tests/extensions/applications/extended_application.py b/tests/integration_tests/extensions/applications/extended_application.py index 159cfd06..fd6fea3f 100644 --- a/tests/integration_tests/extensions/applications/extended_application.py +++ b/tests/integration_tests/extensions/applications/extended_application.py @@ -24,7 +24,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ExtendedApplication(Application, identifier="ExtendedApplication"): +class ExtendedApplication(Application, discriminator="ExtendedApplication"): """ Clone of web browser that uses the extension framework instead of being part of PrimAITE directly. diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py index 37a05b6e..86da0610 100644 --- a/tests/integration_tests/extensions/nodes/giga_switch.py +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -11,7 +11,7 @@ from primaite.simulator.network.hardware.nodes.network.switch import SwitchPort from primaite.simulator.network.transmission.data_link_layer import Frame -class GigaSwitch(NetworkNode, identifier="gigaswitch"): +class GigaSwitch(NetworkNode, discriminator="gigaswitch"): """ A class representing a Layer 2 network switch. diff --git a/tests/integration_tests/extensions/nodes/super_computer.py b/tests/integration_tests/extensions/nodes/super_computer.py index 4af1b748..99c5fdf5 100644 --- a/tests/integration_tests/extensions/nodes/super_computer.py +++ b/tests/integration_tests/extensions/nodes/super_computer.py @@ -6,7 +6,7 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.utils.validation.ipv4_address import IPV4Address -class SuperComputer(HostNode, identifier="supercomputer"): +class SuperComputer(HostNode, discriminator="supercomputer"): """ A basic Computer class. diff --git a/tests/integration_tests/extensions/services/extended_service.py b/tests/integration_tests/extensions/services/extended_service.py index ba247369..79821b6c 100644 --- a/tests/integration_tests/extensions/services/extended_service.py +++ b/tests/integration_tests/extensions/services/extended_service.py @@ -19,7 +19,7 @@ from primaite.utils.validation.port import PORT_LOOKUP _LOGGER = getLogger(__name__) -class ExtendedService(Service, identifier="ExtendedService"): +class ExtendedService(Service, discriminator="ExtendedService"): """ A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE. diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index ed40334f..d2ec06ae 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -15,7 +15,7 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP -class BroadcastTestService(Service, identifier="BroadcastTestService"): +class BroadcastTestService(Service, discriminator="BroadcastTestService"): """A service for sending broadcast and unicast messages over a network.""" class ConfigSchema(Service.ConfigSchema): @@ -51,7 +51,7 @@ class BroadcastTestService(Service, identifier="BroadcastTestService"): ) -class BroadcastTestClient(Application, identifier="BroadcastTestClient"): +class BroadcastTestClient(Application, discriminator="BroadcastTestClient"): """A client application to receive broadcast and unicast messages.""" class ConfigSchema(Service.ConfigSchema): diff --git a/tests/integration_tests/system/test_service_listening_on_ports.py b/tests/integration_tests/system/test_service_listening_on_ports.py index 84413ac9..db5381d0 100644 --- a/tests/integration_tests/system/test_service_listening_on_ports.py +++ b/tests/integration_tests/system/test_service_listening_on_ports.py @@ -14,7 +14,7 @@ from primaite.utils.validation.port import PORT_LOOKUP from tests import TEST_ASSETS_ROOT -class _DatabaseListener(Service, identifier="_DatabaseListener"): +class _DatabaseListener(Service, discriminator="_DatabaseListener"): class ConfigSchema(Service.ConfigSchema): """ConfigSchema for _DatabaseListener.""" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py index 16a4c9ad..9e448b87 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_application_registry.py @@ -5,14 +5,14 @@ from primaite.simulator.system.applications.application import Application def test_adding_to_app_registry(): - class temp_application(Application, identifier="temp_app"): + class temp_application(Application, discriminator="temp_app"): pass assert Application._registry["temp_app"] is temp_application with pytest.raises(ValueError): - class another_application(Application, identifier="temp_app"): + class another_application(Application, discriminator="temp_app"): pass # This is kinda evil... diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index bdf9cfee..12cb736d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -11,7 +11,7 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP from primaite.utils.validation.port import PORT_LOOKUP -class TestSoftware(Service, identifier="TestSoftware"): +class TestSoftware(Service, discriminator="TestSoftware"): class ConfigSchema(Service.ConfigSchema): """ConfigSChema for TestSoftware."""