Merge agents and actions branches + fix import / subclass errors
This commit is contained in:
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty.
|
||||
- Actions within PrimAITE are now extensible, allowing for plugin support.
|
||||
|
||||
|
||||
## [3.3.0] - 2024-09-04
|
||||
|
||||
|
||||
67
docs/source/how_to_guides/extensible_actions.rst
Normal file
67
docs/source/how_to_guides/extensible_actions.rst
Normal file
@@ -0,0 +1,67 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
|
||||
Extensible Actions
|
||||
******************
|
||||
|
||||
|
||||
Changes to Actions class Structure.
|
||||
===================================
|
||||
|
||||
Actions within PrimAITE have been updated to inherit from a base class, AbstractAction, standardising their format and allowing for easier creation of custom actions. Actions now use a ``ConfigSchema`` to define the possible configuration variables, and use pydantic to enforce correct parameters are passed through.
|
||||
|
||||
|
||||
Developing Custom Actions.
|
||||
==========================
|
||||
|
||||
Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and contain 3 key items:
|
||||
|
||||
#. ConfigSchema class
|
||||
|
||||
#. Unique Identifier
|
||||
|
||||
#. `from_request` method.
|
||||
|
||||
|
||||
ConfigSchema
|
||||
############
|
||||
|
||||
The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file.
|
||||
|
||||
|
||||
Unique Identifier
|
||||
#################
|
||||
|
||||
When declaring a custom class, it must have a unique identifier string, that allows PrimAITE to generate the correct action when needed.
|
||||
|
||||
.. code:: Python
|
||||
|
||||
class CreateDirectoryAction(AbstractAction, identifier="node_folder_create")
|
||||
|
||||
config: CreateDirectoryAction.ConfigSchema
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
node_name: str
|
||||
directory_name: str
|
||||
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
return ["network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"folder",
|
||||
config.directory_name,
|
||||
]
|
||||
|
||||
The above action would fail pydantic validation as the identifier "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`.
|
||||
|
||||
|
||||
from_request method
|
||||
###################
|
||||
|
||||
PrimAITE actions need to be have a `from_request` method, which can be passed to the `RequestManager` for processing. This allows the custom action to be actioned within the simulation environment.
|
||||
57
docs/source/how_to_guides/extensible_rewards.rst
Normal file
57
docs/source/how_to_guides/extensible_rewards.rst
Normal file
@@ -0,0 +1,57 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
Extensible Rewards
|
||||
******************
|
||||
Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x as new reward
|
||||
types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE
|
||||
core repository).
|
||||
|
||||
Changes to reward class structure.
|
||||
==================================
|
||||
|
||||
Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel).
|
||||
Within the reward class there is a ConfigSchema class responsible for ensuring the config file data
|
||||
is in the correct format. This also means there is little (if no) requirement for and `__init__`
|
||||
method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`.
|
||||
Each class requires an identifier string which is used by the ConfigSchema class to verify that it
|
||||
hasn't previously been added to the registry.
|
||||
|
||||
Inheriting from `BaseModel` removes the need for an `__init__` method but means that object
|
||||
attributes need to be passed by keyword.
|
||||
|
||||
To add a new reward class follow the example below. Note that the type attribute in the
|
||||
`ConfigSchema` class should match the type used in the config file to define the reward.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
config: "DatabaseFileIntegrity.ConfigSchema"
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseFileIntegrity."""
|
||||
|
||||
type: str = "DATABASE_FILE_INTEGRITY"
|
||||
node_hostname: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Changes to YAML file.
|
||||
=====================
|
||||
.. code:: YAML
|
||||
|
||||
There's no longer a need to provide a `dns_server` as an option in the simulation section
|
||||
of the config file.
|
||||
@@ -55,50 +55,50 @@ agents:
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
- type: do_nothing
|
||||
- type: node_shutdown
|
||||
- type: node_startup
|
||||
- type: host_nic_enable
|
||||
- type: host_nic_enable
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
nic_id: 0
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
nic_id: 0
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
nic_id: 0
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
33
src/primaite/game/agent/actions/__init__.py
Normal file
33
src/primaite/game/agent/actions/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
abstract,
|
||||
acl,
|
||||
application,
|
||||
file,
|
||||
folder,
|
||||
host_nic,
|
||||
manager,
|
||||
network,
|
||||
node,
|
||||
service,
|
||||
session,
|
||||
software,
|
||||
)
|
||||
from primaite.game.agent.actions.manager import ActionManager
|
||||
|
||||
__all__ = (
|
||||
"abstract",
|
||||
"acl",
|
||||
"application",
|
||||
"software",
|
||||
"file",
|
||||
"folder",
|
||||
"host_nic",
|
||||
"manager",
|
||||
"network",
|
||||
"node",
|
||||
"service",
|
||||
"session",
|
||||
"ActionManager",
|
||||
)
|
||||
36
src/primaite/game/agent/actions/abstract.py
Normal file
36
src/primaite/game/agent/actions/abstract.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, ClassVar, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
|
||||
class AbstractAction(BaseModel, ABC):
|
||||
"""Base class for actions."""
|
||||
|
||||
config: "AbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Base configuration schema for Actions."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAction]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create new action under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
pass
|
||||
188
src/primaite/game/agent/actions/acl.py
Normal file
188
src/primaite/game/agent/actions/acl.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import protocol_validator
|
||||
from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address
|
||||
from primaite.utils.validation.port import port_validator
|
||||
|
||||
__all__ = (
|
||||
"RouterACLAddRuleAction",
|
||||
"RouterACLRemoveRuleAction",
|
||||
"FirewallACLAddRuleAction",
|
||||
"FirewallACLRemoveRuleAction",
|
||||
)
|
||||
|
||||
|
||||
class ACLAddRuleAbstractAction(AbstractAction, ABC):
|
||||
"""Base abstract class for ACL add rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL add rule abstract actions."""
|
||||
|
||||
src_ip: IPV4Address
|
||||
protocol_name: str
|
||||
permission: str
|
||||
position: int
|
||||
dst_ip: IPV4Address
|
||||
src_port: int
|
||||
dst_port: int
|
||||
src_wildcard: int
|
||||
dst_wildcard: int
|
||||
|
||||
@field_validator(
|
||||
"src_port",
|
||||
"dst_port",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def valid_port(cls, v: str) -> int:
|
||||
"""Check that inputs are valid."""
|
||||
return port_validator(v)
|
||||
|
||||
@field_validator(
|
||||
"src_ip",
|
||||
"dst_ip",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def valid_ip(cls, v: str) -> str:
|
||||
"""Check that a valid IP has been provided for src and dst."""
|
||||
return ipv4_validator(v)
|
||||
|
||||
@field_validator(
|
||||
"protocol_name",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def is_valid_protocol(cls, v: str) -> bool:
|
||||
"""Check that we are using a valid protocol."""
|
||||
return protocol_validator(v)
|
||||
|
||||
|
||||
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
|
||||
"""Base abstract class for ACL remove rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL remove rule abstract actions."""
|
||||
|
||||
position: int
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
config: "RouterACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for RouterACLAddRuleAction."""
|
||||
|
||||
target_router: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_router,
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
config.src_ip,
|
||||
config.src_wildcard,
|
||||
config.src_port,
|
||||
config.dst_ip,
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
config: "RouterACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RouterACLRemoveRuleAction."""
|
||||
|
||||
target_router: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
|
||||
|
||||
|
||||
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
|
||||
"""Action which adds a rule to a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLAddRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
config.src_ip,
|
||||
config.src_wildcard,
|
||||
config.src_port,
|
||||
config.dst_ip,
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
|
||||
"""Action which removes a rule from a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLRemoveRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"remove_rule",
|
||||
config.position,
|
||||
]
|
||||
137
src/primaite/game/agent/actions/application.py
Normal file
137
src/primaite/game/agent/actions/application.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.abstract import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeApplicationExecuteAction",
|
||||
"NodeApplicationScanAction",
|
||||
"NodeApplicationCloseAction",
|
||||
"NodeApplicationFixAction",
|
||||
"NodeApplicationInstallAction",
|
||||
"NodeApplicationRemoveAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeApplicationAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Base class for application actions.
|
||||
|
||||
Any action which applies to an application and uses node_id and application_id as its only two parameters can
|
||||
inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeApplicationAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for Node Application actions."""
|
||||
|
||||
node_name: str
|
||||
application_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"application",
|
||||
config.application_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
|
||||
"""Action which executes an application."""
|
||||
|
||||
config: "NodeApplicationExecuteAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationExecuteAction."""
|
||||
|
||||
verb: str = "execute"
|
||||
|
||||
|
||||
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
|
||||
"""Action which scans an application."""
|
||||
|
||||
config: "NodeApplicationScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationScanAction."""
|
||||
|
||||
verb: str = "scan"
|
||||
|
||||
|
||||
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
|
||||
"""Action which closes an application."""
|
||||
|
||||
config: "NodeApplicationCloseAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationCloseAction."""
|
||||
|
||||
verb: str = "close"
|
||||
|
||||
|
||||
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
|
||||
"""Action which fixes an application."""
|
||||
|
||||
config: "NodeApplicationFixAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationFixAction."""
|
||||
|
||||
verb: str = "fix"
|
||||
|
||||
|
||||
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
|
||||
"""Action which installs an application."""
|
||||
|
||||
config: "NodeApplicationInstallAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationInstallAction."""
|
||||
|
||||
verb: str = "install"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
config.verb,
|
||||
config.application_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
|
||||
"""Action which removes/uninstalls an application."""
|
||||
|
||||
config: "NodeApplicationRemoveAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationRemoveAction."""
|
||||
|
||||
verb: str = "uninstall"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
config.verb,
|
||||
config.application_name,
|
||||
]
|
||||
189
src/primaite/game/agent/actions/file.py
Normal file
189
src/primaite/game/agent/actions/file.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeFileCreateAction",
|
||||
"NodeFileScanAction",
|
||||
"NodeFileDeleteAction",
|
||||
"NodeFileRestoreAction",
|
||||
"NodeFileCorruptAction",
|
||||
"NodeFileAccessAction",
|
||||
"NodeFileCheckhashAction",
|
||||
"NodeFileRepairAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction, ABC):
|
||||
"""Abstract base class for file actions.
|
||||
|
||||
Any action which applies to a file and uses node_name, folder_name, and file_name as its
|
||||
only three parameters can inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeFileAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeFileAbstractAction."""
|
||||
|
||||
node_name: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
"folder",
|
||||
config.folder_name,
|
||||
"file",
|
||||
config.file_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
|
||||
"""Action which creates a new file in a given folder."""
|
||||
|
||||
config: "NodeFileCreateAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCreateAction."""
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
force: bool = False
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"file",
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
|
||||
"""Action which scans a file."""
|
||||
|
||||
config: "NodeFileScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
|
||||
"""Action which deletes a file."""
|
||||
|
||||
config: "NodeFileDeleteAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileDeleteAction."""
|
||||
|
||||
verb: ClassVar[str] = "delete"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"file",
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
|
||||
"""Action which restores a file."""
|
||||
|
||||
config: "NodeFileRestoreAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileRestoreAction."""
|
||||
|
||||
verb: ClassVar[str] = "restore"
|
||||
|
||||
|
||||
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
|
||||
"""Action which corrupts a file."""
|
||||
|
||||
config: "NodeFileCorruptAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCorruptAction."""
|
||||
|
||||
verb: ClassVar[str] = "corrupt"
|
||||
|
||||
|
||||
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
|
||||
"""Action which increases a file's access count."""
|
||||
|
||||
config: "NodeFileAccessAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileAccessAction."""
|
||||
|
||||
verb: ClassVar[str] = "access"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"):
|
||||
"""Action which checks the hash of a file."""
|
||||
|
||||
config: "NodeFileCheckhashAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCheckhashAction."""
|
||||
|
||||
verb: ClassVar[str] = "checkhash"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
|
||||
"""Action which repairs a file."""
|
||||
|
||||
config: "NodeFileRepairAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeFileRepairAction."""
|
||||
|
||||
verb: ClassVar[str] = "repair"
|
||||
117
src/primaite/game/agent/actions/folder.py
Normal file
117
src/primaite/game/agent/actions/folder.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeFolderScanAction",
|
||||
"NodeFolderCheckhashAction",
|
||||
"NodeFolderRepairAction",
|
||||
"NodeFolderRestoreAction",
|
||||
"NodeFolderCreateAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
|
||||
Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
config: "NodeFolderAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NodeFolder actions."""
|
||||
|
||||
node_name: str
|
||||
folder_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
"folder",
|
||||
config.folder_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
|
||||
"""Action which scans a folder."""
|
||||
|
||||
config: "NodeFolderScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
|
||||
"""Action which checks the hash of a folder."""
|
||||
|
||||
config: "NodeFolderCheckhashAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderCheckhashAction."""
|
||||
|
||||
verb: ClassVar[str] = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
|
||||
"""Action which repairs a folder."""
|
||||
|
||||
config: "NodeFolderRepairAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderRepairAction."""
|
||||
|
||||
verb: ClassVar[str] = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
|
||||
"""Action which restores a folder."""
|
||||
|
||||
config: "NodeFolderRestoreAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderRestoreAction."""
|
||||
|
||||
verb: ClassVar[str] = "restore"
|
||||
|
||||
|
||||
class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"):
|
||||
"""Action which creates a new folder."""
|
||||
|
||||
config: "NodeFolderCreateAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderCreateAction."""
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"folder",
|
||||
config.folder_name,
|
||||
]
|
||||
62
src/primaite/game/agent/actions/host_nic.py
Normal file
62
src/primaite/game/agent/actions/host_nic.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("HostNICEnableAction", "HostNICDisableAction")
|
||||
|
||||
|
||||
class HostNICAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Abstract base class for NIC actions.
|
||||
|
||||
Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base
|
||||
class.
|
||||
"""
|
||||
|
||||
config: "HostNICAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for HostNIC actions."""
|
||||
|
||||
node_name: str
|
||||
nic_num: int
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.nic_num is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"network_interface",
|
||||
config.nic_num,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
|
||||
"""Action which enables a NIC."""
|
||||
|
||||
config: "HostNICEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for HostNICEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
config: "HostNICDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for HostNICDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
138
src/primaite/game/agent/actions/manager.py
Normal file
138
src/primaite/game/agent/actions/manager.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""yaml example.
|
||||
|
||||
agents:
|
||||
- name: agent_1
|
||||
action_space:
|
||||
actions:
|
||||
- do_nothing
|
||||
- node_service_start
|
||||
- node_service_stop
|
||||
action_map:
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
# from primaite.game.game import PrimaiteGame # TODO: Breaks things
|
||||
from primaite.game.agent.actions.abstract import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("DoNothingAction", "ActionManager")
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction, identifier="do_nothing"):
|
||||
"""Do Nothing Action."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for do_nothingAction."""
|
||||
|
||||
type: str = "do_nothing"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actions: List[Dict], # stores list of actions available to agent
|
||||
nodes: List[Dict], # extra configuration for each node
|
||||
act_map: Optional[
|
||||
Dict[int, Dict]
|
||||
] = None, # allows restricting set of possible actions - TODO: Refactor to be a list?
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
:param game: Reference to the game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param actions: List of action specs which should be made available to the agent. The keys of each spec are:
|
||||
'type' and 'options' for passing any options to the action class's init method
|
||||
:type actions: List[dict]
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.actions: Dict[str, AbstractAction] = {}
|
||||
for act_spec in actions:
|
||||
act_type = act_spec.get("type")
|
||||
self.actions[act_type] = AbstractAction._registry[act_type]
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
Action mapping that converts an integer to a specific action and parameter choice.
|
||||
|
||||
For example :
|
||||
{0: ("node_service_scan", {node_name:"client_1", service_name:"WebBrowser"})}
|
||||
"""
|
||||
if act_map is None:
|
||||
# raise RuntimeError("Action map must be specified in the config file.")
|
||||
pass
|
||||
else:
|
||||
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format."""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
"""The CAOS format is basically a action identifier, followed by parameters stored in a dictionary"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_class = AbstractAction._registry[action_identifier]
|
||||
config = act_class.ConfigSchema(**action_options)
|
||||
return act_class.form_request(config=config)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": # noqa: F821
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
The action space config supports the following three sections:
|
||||
1. ``action_list``
|
||||
``action_list`` contains a list action components which need to be included in the action space.
|
||||
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
|
||||
which will be passed to the action class's __init__ method during initialisation.
|
||||
2. ``action_map``
|
||||
Since the agent uses a discrete action space which acts as a flattened version of the component-based
|
||||
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
|
||||
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
|
||||
correspond to "node_service_scan" with ``node_name="server"`` and
|
||||
``service_name="WebBrowser"``, action 2 can be "
|
||||
3. ``options``
|
||||
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
|
||||
|
||||
:param game: The Primaite Game to which the agent belongs.
|
||||
:type game: PrimaiteGame
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
obj = cls(
|
||||
actions=cfg["action_list"],
|
||||
**cfg["options"],
|
||||
protocols=game.options.protocols,
|
||||
ports=game.options.ports,
|
||||
act_map=cfg.get("action_map"),
|
||||
)
|
||||
|
||||
return obj
|
||||
57
src/primaite/game/agent/actions/network.py
Normal file
57
src/primaite/game/agent/actions/network.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
|
||||
|
||||
|
||||
class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"):
|
||||
"""Base class for Network port actions."""
|
||||
|
||||
config: "NetworkPortAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NetworkPort actions."""
|
||||
|
||||
target_nodename: str
|
||||
port_id: int
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.target_nodename is None or config.port_id is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_nodename,
|
||||
"network_interface",
|
||||
config.port_id,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"):
|
||||
"""Action which enables are port on a router or a firewall."""
|
||||
|
||||
config: "NetworkPortEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NetworkPortEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
|
||||
"""Action which disables are port on a router or a firewall."""
|
||||
|
||||
config: "NetworkPortDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NetworkPortDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
195
src/primaite/game/agent/actions/node.py
Normal file
195
src/primaite/game/agent/actions/node.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeOSScanAction",
|
||||
"NodeShutdownAction",
|
||||
"NodeStartupAction",
|
||||
"NodeResetAction",
|
||||
"NodeNMAPPingScanAction",
|
||||
"NodeNMAPPortScanAction",
|
||||
"NodeNetworkServiceReconAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction, identifier="node_abstract"):
|
||||
"""
|
||||
Abstract base class for node actions.
|
||||
|
||||
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for Node actions."""
|
||||
|
||||
node_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
print(config)
|
||||
return ["network", "node", config.node_name, config.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"):
|
||||
"""Action which scans a node's OS."""
|
||||
|
||||
config: "NodeOSScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeOSScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"):
|
||||
"""Action which shuts down a node."""
|
||||
|
||||
config: "NodeShutdownAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeShutdownAction."""
|
||||
|
||||
verb: ClassVar[str] = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction, identifier="node_startup"):
|
||||
"""Action which starts up a node."""
|
||||
|
||||
config: "NodeStartupAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeStartupAction."""
|
||||
|
||||
verb: ClassVar[str] = "startup"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
|
||||
"""Action which resets a node."""
|
||||
|
||||
config: "NodeResetAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeResetAction."""
|
||||
|
||||
verb: ClassVar[str] = "reset"
|
||||
|
||||
|
||||
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
|
||||
"""Base class for NodeNMAP actions."""
|
||||
|
||||
config: "NodeNMAPAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration Schema for NodeNMAP actions."""
|
||||
|
||||
target_ip_address: Union[str, List[str]]
|
||||
show: bool = False
|
||||
node_name: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
# NMAP action requests don't share a common format for their requests
|
||||
# This is just a placeholder to ensure the method is defined.
|
||||
pass
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
|
||||
"""Action which performs an NMAP ping scan."""
|
||||
|
||||
config: "NodeNMAPPingScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNMAPPingScanAction."""
|
||||
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"application",
|
||||
"NMAP",
|
||||
"ping_scan",
|
||||
{"target_ip_address": config.target_ip_address, "show": config.show},
|
||||
]
|
||||
|
||||
|
||||
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
config: "NodeNMAPPortScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeNMAPPortScanAction."""
|
||||
|
||||
source_node: str
|
||||
target_protocol: Optional[Union[str, List[str]]] = (None,)
|
||||
target_port: Optional[Union[str, List[str]]] = (None,)
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(
|
||||
cls,
|
||||
config: ConfigSchema,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"port_scan",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
config: "NodeNetworkServiceReconAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNetworkServiceReconAction."""
|
||||
|
||||
target_protocol: Optional[Union[str, List[str]]] = (None,)
|
||||
target_port: Optional[Union[str, List[str]]] = (None,)
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(
|
||||
cls,
|
||||
config: ConfigSchema,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"network_service_recon",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
135
src/primaite/game/agent/actions/service.py
Normal file
135
src/primaite/game/agent/actions/service.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeServiceScanAction",
|
||||
"NodeServiceStopAction",
|
||||
"NodeServiceStartAction",
|
||||
"NodeServicePauseAction",
|
||||
"NodeServiceResumeAction",
|
||||
"NodeServiceRestartAction",
|
||||
"NodeServiceDisableAction",
|
||||
"NodeServiceEnableAction",
|
||||
"NodeServiceFixAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
|
||||
"""Abstract Action for Node Service related actions.
|
||||
|
||||
Any actions which use node_name and service_name can inherit from this class.
|
||||
"""
|
||||
|
||||
config: "NodeServiceAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
node_name: str
|
||||
service_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.node_name, "service", config.service_name, config.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
|
||||
"""Action which scans a service."""
|
||||
|
||||
config: "NodeServiceScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
|
||||
"""Action which stops a service."""
|
||||
|
||||
config: "NodeServiceStopAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceStopAction."""
|
||||
|
||||
verb: ClassVar[str] = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
|
||||
"""Action which starts a service."""
|
||||
|
||||
config: "NodeServiceStartAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceStartAction."""
|
||||
|
||||
verb: ClassVar[str] = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
|
||||
"""Action which pauses a service."""
|
||||
|
||||
config: "NodeServicePauseAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServicePauseAction."""
|
||||
|
||||
verb: ClassVar[str] = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
|
||||
"""Action which resumes a service."""
|
||||
|
||||
config: "NodeServiceResumeAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceResumeAction."""
|
||||
|
||||
verb: ClassVar[str] = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
|
||||
"""Action which restarts a service."""
|
||||
|
||||
config: "NodeServiceRestartAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceRestartAction."""
|
||||
|
||||
verb: ClassVar[str] = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
|
||||
"""Action which disables a service."""
|
||||
|
||||
config: "NodeServiceDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
|
||||
"""Action which enables a service."""
|
||||
|
||||
config: "NodeServiceEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
|
||||
"""Action which fixes a service."""
|
||||
|
||||
config: "NodeServiceFixAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceFixAction."""
|
||||
|
||||
verb: ClassVar[str] = "fix"
|
||||
108
src/primaite/game/agent/actions/session.py
Normal file
108
src/primaite/game/agent/actions/session.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeSessionsRemoteLoginAction",
|
||||
"NodeSessionsRemoteLogoutAction",
|
||||
"NodeAccountChangePasswordAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"):
|
||||
"""Base class for NodeSession actions."""
|
||||
|
||||
config: "NodeSessionAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NodeSessionAbstractActions."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""
|
||||
Abstract method for request forming.
|
||||
|
||||
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
|
||||
"""Action which performs a remote session login."""
|
||||
|
||||
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLoginAction."""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"ssh_to_remote",
|
||||
config.username,
|
||||
config.password,
|
||||
config.remote_ip,
|
||||
]
|
||||
|
||||
|
||||
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"):
|
||||
"""Action which performs a remote session logout."""
|
||||
|
||||
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
|
||||
|
||||
verb: str = "remote_logoff"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
|
||||
|
||||
|
||||
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):
|
||||
"""Action which changes the password for a user."""
|
||||
|
||||
config: "NodeAccountChangePasswordAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeAccountsChangePasswordAction."""
|
||||
|
||||
username: str
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"UserManager",
|
||||
"change_password",
|
||||
config.username,
|
||||
config.current_password,
|
||||
config.new_password,
|
||||
]
|
||||
238
src/primaite/game/agent/actions/software.py
Normal file
238
src/primaite/game/agent/actions/software.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"ConfigureRansomwareScriptAction",
|
||||
"ConfigureDoSBotAction",
|
||||
"ConfigureC2BeaconAction",
|
||||
"NodeSendRemoteCommandAction",
|
||||
"TerminalC2ServerAction",
|
||||
"RansomwareLaunchC2ServerAction",
|
||||
"ExfiltrationC2ServerAction",
|
||||
"ConfigureDatabaseClientAction",
|
||||
)
|
||||
|
||||
|
||||
class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_ransomware_configure"):
|
||||
"""Action which sets config parameters for a ransomware script on a node."""
|
||||
|
||||
config: "ConfigureRansomwareScriptAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureRansomwareScriptAction."""
|
||||
|
||||
node_name: str
|
||||
server_ip_address: Optional[str]
|
||||
server_password: Optional[str]
|
||||
payload: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"application",
|
||||
"RansomwareScript",
|
||||
"configure",
|
||||
config.model_config,
|
||||
]
|
||||
|
||||
|
||||
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
|
||||
"""Action which sets config parameters for a DoS bot on a node."""
|
||||
|
||||
config: "ConfigureDoSBotAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
target_ip_address: Optional[str] = None
|
||||
target_port: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
repeat: Optional[bool] = None
|
||||
port_scan_p_of_success: Optional[float] = None
|
||||
dos_intensity: Optional[float] = None
|
||||
max_sessions: Optional[int] = None
|
||||
|
||||
def __init__(self, manager: "ActionManager", **kwargs) -> None:
|
||||
super().__init__(manager=manager)
|
||||
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
self.ConfigSchema.model_validate(config) # check that options adhere to schema
|
||||
return ["network", "node", config.node_name, "application", "DoSBot", "configure", config]
|
||||
|
||||
|
||||
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
|
||||
"""Action which configures a C2 Beacon based on the parameters given."""
|
||||
|
||||
config: "ConfigureC2BeaconAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureC2BeaconAction."""
|
||||
|
||||
node_name: str
|
||||
c2_server_ip_address: str
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
masquerade_protocol: str = Field(default="TCP")
|
||||
masquerade_port: str = Field(default="HTTP")
|
||||
|
||||
@field_validator(
|
||||
"c2_server_ip_address",
|
||||
"keep_alive_frequency",
|
||||
"masquerade_protocol",
|
||||
"masquerade_port",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def not_none(cls, v: str, info: ValidationInfo) -> int:
|
||||
"""If None is passed, use the default value instead."""
|
||||
if v is None:
|
||||
return cls.model_fields[info.field_name].default
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config]
|
||||
|
||||
|
||||
class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):
|
||||
"""Action which sends a terminal command to a remote node via SSH."""
|
||||
|
||||
config: "NodeSendRemoteCommandAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSendRemoteCommandAction."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
command: RequestFormat
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"send_remote_command",
|
||||
config.remote_ip,
|
||||
{"command": config.command},
|
||||
]
|
||||
|
||||
|
||||
class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"):
|
||||
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
|
||||
|
||||
config: "TerminalC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
commands: Union[List[RequestFormat], RequestFormat]
|
||||
ip_address: Optional[str]
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
|
||||
command_model = {
|
||||
"commands": config.commands,
|
||||
"ip_address": config.ip_address,
|
||||
"username": config.username,
|
||||
"password": config.password,
|
||||
}
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model]
|
||||
|
||||
|
||||
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"):
|
||||
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
|
||||
|
||||
config: "RansomwareLaunchC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RansomwareLaunchC2ServerAction."""
|
||||
|
||||
node_name: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
# This action currently doesn't require any further configuration options.
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
|
||||
|
||||
|
||||
class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"):
|
||||
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
|
||||
|
||||
config: "ExfiltrationC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
target_ip_address: str
|
||||
target_file_name: str
|
||||
target_folder_name: str
|
||||
exfiltration_folder_name: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
|
||||
command_model = {
|
||||
"target_file_name": config.target_file_name,
|
||||
"target_folder_name": config.target_folder_name,
|
||||
"exfiltration_folder_name": config.exfiltration_folder_name,
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"username": config.username,
|
||||
"password": config.password,
|
||||
}
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model]
|
||||
|
||||
|
||||
class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"):
|
||||
"""Action which sets config parameters for a database client on a node."""
|
||||
|
||||
config: "ConfigureDatabaseClientAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", config.model_config]
|
||||
@@ -73,11 +73,13 @@ class AbstractAgent(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
type: str = "AbstractAgent"
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@property
|
||||
def flatten_obs(self) -> bool:
|
||||
|
||||
@@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs
|
||||
__all__ = [
|
||||
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
|
||||
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",]
|
||||
# fmt: on
|
||||
|
||||
@@ -31,7 +31,7 @@ class AbstractObservation(ABC):
|
||||
"""Initialise an observation. This method must be overwritten."""
|
||||
self.default_observation: ObsType
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
|
||||
@@ -40,6 +40,8 @@ class AbstractObservation(ABC):
|
||||
:raises ValueError: When attempting to create a component with a name that is already in use.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate observation component type {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -27,9 +27,10 @@ the structure:
|
||||
service_ref: web_server_database_client
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -42,25 +43,28 @@ _LOGGER = getLogger(__name__)
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
class AbstractReward(BaseModel):
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
config: "AbstractReward.ConfigSchema"
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return 0.0
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Config schema for AbstractReward."""
|
||||
|
||||
type: str
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate reward {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict) -> "AbstractReward":
|
||||
def from_config(cls, config: Dict) -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
@@ -68,11 +72,28 @@ class AbstractReward:
|
||||
:return: The reward component.
|
||||
:rtype: AbstractReward
|
||||
"""
|
||||
return cls()
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid reward type {config['type']}")
|
||||
reward_class = cls._registry[config["type"]]
|
||||
reward_obj = reward_class(config=reward_class.ConfigSchema(**config))
|
||||
return reward_obj
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
class DummyReward(AbstractReward, identifier="DUMMY"):
|
||||
"""Dummy reward function component which always returns 0.0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
@@ -86,41 +107,21 @@ class DummyReward(AbstractReward):
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "DummyReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor. Should be empty.
|
||||
:type config: dict
|
||||
:return: The reward component.
|
||||
:rtype: DummyReward
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward):
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
|
||||
"""Initialise the reward component.
|
||||
config: "DatabaseFileIntegrity.ConfigSchema"
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
:param node_hostname: Hostname of the node which contains the database file.
|
||||
:type node_hostname: str
|
||||
:param folder_name: folder which contains the database file.
|
||||
:type folder_name: str
|
||||
:param file_name: name of the database file.
|
||||
:type file_name: str
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
node_hostname,
|
||||
"file_system",
|
||||
"folders",
|
||||
folder_name,
|
||||
"files",
|
||||
file_name,
|
||||
]
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseFileIntegrity."""
|
||||
|
||||
type: str = "DATABASE_FILE_INTEGRITY"
|
||||
node_hostname: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
@@ -132,6 +133,17 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.config.node_hostname,
|
||||
"file_system",
|
||||
"folders",
|
||||
self.config.folder_name,
|
||||
"files",
|
||||
self.config.file_name,
|
||||
]
|
||||
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.debug(
|
||||
@@ -148,44 +160,21 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
else:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "DatabaseFileIntegrity":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:return: The reward component.
|
||||
:rtype: DatabaseFileIntegrity
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
folder_name = config.get("folder_name")
|
||||
file_name = config.get("file_name")
|
||||
if not (node_hostname and folder_name and file_name):
|
||||
msg = f"{cls.__name__} could not be initialised with parameters {config}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
|
||||
"""Initialise the reward component.
|
||||
config: "WebServer404Penalty.ConfigSchema"
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
:param node_hostname: Hostname of the node which contains the web server service.
|
||||
:type node_hostname: str
|
||||
:param service_name: Name of the web server service.
|
||||
:type service_name: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebServer404Penalty."""
|
||||
|
||||
type: str = "WEB_SERVER_404_PENALTY"
|
||||
node_hostname: str
|
||||
service_name: str
|
||||
sticky: bool = True
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
@@ -197,6 +186,13 @@ class WebServer404Penalty(AbstractReward):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.config.node_hostname,
|
||||
"services",
|
||||
self.config.service_name,
|
||||
]
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# if webserver is no longer installed on the node, return 0
|
||||
@@ -211,54 +207,27 @@ class WebServer404Penalty(AbstractReward):
|
||||
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
|
||||
|
||||
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
|
||||
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
|
||||
elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0
|
||||
self.reward = 0.0
|
||||
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
|
||||
pass
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "WebServer404Penalty":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:return: The reward component.
|
||||
:rtype: WebServer404Penalty
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
service_name = config.get("service_name")
|
||||
if not (node_hostname and service_name):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised from config because node_name and service_ref were not "
|
||||
"found in reward config."
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
raise ValueError(msg)
|
||||
sticky = config.get("sticky", True)
|
||||
|
||||
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
|
||||
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward):
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
config: "WebpageUnavailablePenalty.ConfigSchema"
|
||||
reward: float = 0.0
|
||||
location_in_state: List[str] = [""] # Calculate in __init__()?
|
||||
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebpageUnavailablePenalty."""
|
||||
|
||||
type: str = "WEBPAGE_UNAVAILABLE_PENALTY"
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -274,6 +243,13 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.config.node_hostname,
|
||||
"applications",
|
||||
"WebBrowser",
|
||||
]
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -283,14 +259,14 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self._node,
|
||||
self.config.node_hostname,
|
||||
"application",
|
||||
"WebBrowser",
|
||||
"execute",
|
||||
]
|
||||
|
||||
# skip calculating if sticky and no new codes, reusing last step value
|
||||
if not request_attempted and self.sticky:
|
||||
if not request_attempted and self.config.sticky:
|
||||
return self.reward
|
||||
|
||||
if last_action_response.response.status != "success":
|
||||
@@ -298,7 +274,7 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
|
||||
_LOGGER.debug(
|
||||
"Web browser reward could not be calculated because the web browser history on node",
|
||||
f"{self._node} was not reported in the simulation state. Returning 0.0",
|
||||
f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0",
|
||||
)
|
||||
self.reward = 0.0
|
||||
else:
|
||||
@@ -312,37 +288,19 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> AbstractReward:
|
||||
"""
|
||||
Build the reward component object from config.
|
||||
|
||||
:param config: Configuration dictionary.
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema"
|
||||
reward: float = 0.0
|
||||
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
|
||||
|
||||
type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"
|
||||
node_hostname: str
|
||||
sticky: bool = True
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -362,7 +320,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self._node,
|
||||
self.config.node_hostname,
|
||||
"application",
|
||||
"DatabaseClient",
|
||||
"execute",
|
||||
@@ -371,7 +329,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
if request_attempted: # if agent makes request, always recalculate fresh value
|
||||
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
elif not self.sticky: # if no new request and not sticky, set reward to 0
|
||||
elif not self.config.sticky: # if no new request and not sticky, set reward to 0
|
||||
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
|
||||
self.reward = 0.0
|
||||
else: # if no new request and sticky, reuse reward value from last step
|
||||
@@ -380,47 +338,30 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
"""
|
||||
Build the reward component object from config.
|
||||
|
||||
:param config: Configuration dictionary.
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
|
||||
class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
config: "SharedReward.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for SharedReward."""
|
||||
|
||||
type: str = "SHARED_REWARD"
|
||||
agent_name: str
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Initialise the shared reward.
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
|
||||
correctly.
|
||||
|
||||
:param agent_name: The name whose reward is an input
|
||||
:type agent_name: Optional[str]
|
||||
SharedReward should not be used until the game layer replaces self.callback with a reference to the
|
||||
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
|
||||
an error.
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
"""Agent whose reward to track."""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
SharedReward should not be used until the game layer replaces self.callback with a reference to the
|
||||
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
|
||||
an error.
|
||||
"""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it.
|
||||
@@ -432,36 +373,25 @@ class SharedReward(AbstractReward):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "SharedReward":
|
||||
"""
|
||||
Build the SharedReward object from config.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:type config: Dict
|
||||
"""
|
||||
agent_name = config.get("agent_name")
|
||||
return cls(agent_name=agent_name)
|
||||
return self.callback(self.config.agent_name)
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward):
|
||||
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
"""
|
||||
Initialise the reward.
|
||||
config: "ActionPenalty.ConfigSchema"
|
||||
|
||||
Reward or penalise agents for doing nothing or taking actions.
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for ActionPenalty.
|
||||
|
||||
:param action_penalty: Reward to give agents for taking any action except DONOTHING
|
||||
:param action_penalty: Reward to give agents for taking any action except do_nothing
|
||||
:type action_penalty: float
|
||||
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
|
||||
:param do_nothing_penalty: Reward to give agent for taking the do_nothing action
|
||||
:type do_nothing_penalty: float
|
||||
"""
|
||||
self.action_penalty = action_penalty
|
||||
self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied.
|
||||
@@ -473,33 +403,16 @@ class ActionPenalty(AbstractReward):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
if last_action_response.action == "do_nothing":
|
||||
return self.do_nothing_penalty
|
||||
else:
|
||||
return self.action_penalty
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "ActionPenalty":
|
||||
"""Build the ActionPenalty object from config."""
|
||||
action_penalty = config.get("action_penalty", -1.0)
|
||||
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
|
||||
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
|
||||
else:
|
||||
return self.config.action_penalty
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
"SHARED_REWARD": SharedReward,
|
||||
"ACTION_PENALTY": ActionPenalty,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
@@ -534,7 +447,7 @@ class RewardFunction:
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "RewardFunction":
|
||||
"""Create a reward function from a config dictionary.
|
||||
"""Create a reward function from a config dictionary and its related reward class.
|
||||
|
||||
:param config: dict of options for the reward manager's constructor
|
||||
:type config: Dict
|
||||
@@ -545,8 +458,11 @@ class RewardFunction:
|
||||
|
||||
for rew_component_cfg in config["reward_components"]:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
# XXX: If options key is missing add key then add type key.
|
||||
if "options" not in rew_component_cfg:
|
||||
rew_component_cfg["options"] = {}
|
||||
rew_component_cfg["options"]["type"] = rew_type
|
||||
weight = rew_component_cfg.get("weight", 1.0)
|
||||
rew_class = cls.rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}))
|
||||
rew_instance = AbstractReward.from_config(rew_component_cfg["options"])
|
||||
new.register_component(component=rew_instance, weight=weight)
|
||||
return new
|
||||
|
||||
@@ -370,7 +370,7 @@ class PrimaiteGame:
|
||||
|
||||
if service_class is not None:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
|
||||
new_node.software_manager.install(service_class)
|
||||
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
|
||||
new_service = new_node.software_manager.software[service_class.__name__]
|
||||
|
||||
# fixing duration for the service
|
||||
@@ -580,7 +580,7 @@ class PrimaiteGame:
|
||||
for comp, weight in agent.reward_function.reward_components:
|
||||
if isinstance(comp, SharedReward):
|
||||
comp: SharedReward
|
||||
graph[name].add(comp.agent_name)
|
||||
graph[name].add(comp.config.agent_name)
|
||||
|
||||
# while constructing the graph, we might as well set up the reward sharing itself.
|
||||
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable\n"
|
||||
"from prettytable import PrettyTable"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -195,7 +195,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -209,7 +209,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1780,10 +1780,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
|
||||
"from primaite.utils.validation.port import PORT_LOOKUP\n",
|
||||
"\n",
|
||||
"# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n",
|
||||
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n",
|
||||
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n",
|
||||
"c2_beacon.establish()\n",
|
||||
"c2_beacon.show()"
|
||||
]
|
||||
@@ -1804,7 +1805,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -1818,7 +1819,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
@@ -165,13 +165,13 @@
|
||||
"\n",
|
||||
"| node_id | node name |\n",
|
||||
"|---------|------------------|\n",
|
||||
"| 1 | domain_controller|\n",
|
||||
"| 2 | web_server |\n",
|
||||
"| 3 | database_server |\n",
|
||||
"| 4 | backup_server |\n",
|
||||
"| 5 | security_suite |\n",
|
||||
"| 6 | client_1 |\n",
|
||||
"| 7 | client_2 |\n",
|
||||
"| 0 | domain_controller|\n",
|
||||
"| 1 | web_server |\n",
|
||||
"| 2 | database_server |\n",
|
||||
"| 3 | backup_server |\n",
|
||||
"| 4 | security_suite |\n",
|
||||
"| 5 | client_1 |\n",
|
||||
"| 6 | client_2 |\n",
|
||||
"\n",
|
||||
"Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n",
|
||||
"\n",
|
||||
|
||||
@@ -95,7 +95,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -109,7 +109,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -168,7 +168,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -182,7 +182,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -166,9 +166,10 @@
|
||||
"from pathlib import Path\n",
|
||||
"from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n",
|
||||
"from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.file_system.file_system import FileSystem\n",
|
||||
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
|
||||
"from primaite.utils.validation.port import PORT_LOOKUP\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# no applications exist yet so we will create our own.\n",
|
||||
"class MSPaint(Application, identifier=\"MSPaint\"):\n",
|
||||
@@ -182,7 +183,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
|
||||
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=PORT_LOOKUP[\"HTTP\"], protocol = PROTOCOL_LOOKUP[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -249,7 +250,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -263,7 +264,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -532,12 +532,12 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
|
||||
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
|
||||
"\n",
|
||||
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
|
||||
" action=ACLAction.DENY,\n",
|
||||
" protocol=IPProtocol["ICMP"],\n",
|
||||
" protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n",
|
||||
" src_ip_address=\"192.168.10.22\",\n",
|
||||
" position=1\n",
|
||||
")"
|
||||
@@ -650,7 +650,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Literal, Type
|
||||
from typing import Any, ClassVar, Dict, Literal, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -49,7 +49,7 @@ class NetworkNodeAdder(BaseModel):
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a network node adder class.
|
||||
|
||||
@@ -58,6 +58,8 @@ class NetworkNodeAdder(BaseModel):
|
||||
:raises ValueError: When attempting to register a name that is already reserved.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate node adder {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -1545,7 +1545,7 @@ class Node(SimComponent):
|
||||
_identifier: ClassVar[str] = "unknown"
|
||||
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a node type.
|
||||
|
||||
@@ -1553,10 +1553,10 @@ class Node(SimComponent):
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an node with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
identifier = identifier.lower()
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -44,7 +44,7 @@ class Application(IOSoftware):
|
||||
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
|
||||
"""Registry of application types. Automatically populated when subclasses are defined."""
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an application type.
|
||||
|
||||
@@ -52,9 +52,9 @@ class Application(IOSoftware):
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an application with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
return
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -308,6 +308,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
if self.server_ip_address is None:
|
||||
self.sys_log.warning(f"{self.name}: Database server IP address not provided.")
|
||||
return None
|
||||
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
|
||||
@@ -16,7 +16,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface, Node
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ class Service(IOSoftware):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a hostnode type.
|
||||
|
||||
@@ -60,11 +60,11 @@ class Service(IOSoftware):
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
|
||||
identifier = identifier.lower()
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -31,7 +31,7 @@ def ipv4_validator(v: Any) -> IPv4Address:
|
||||
|
||||
IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)]
|
||||
"""
|
||||
IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator..
|
||||
IPv4Address with pre-validation and auto-conversion from str using ipv4_validator..
|
||||
|
||||
This type is essentially an IPv4Address from the standard library's ipaddress module,
|
||||
but with added validation logic. If you use this custom type, the ipv4_validator function
|
||||
|
||||
@@ -205,8 +205,6 @@ simulation:
|
||||
port_scan_p_of_success: 0.8
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
|
||||
@@ -33,7 +33,7 @@ agents:
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: do_nothing
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
@@ -47,7 +47,7 @@ agents:
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
@@ -82,7 +82,7 @@ agents:
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: do_nothing
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
@@ -96,7 +96,7 @@ agents:
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
@@ -132,7 +132,7 @@ agents:
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: do_nothing
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
@@ -235,7 +235,7 @@ agents:
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: do_nothing
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
@@ -265,7 +265,7 @@ agents:
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
|
||||
@@ -96,155 +96,158 @@ agents:
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: FIREWALL_ACL_ADDRULE
|
||||
- type: FIREWALL_ACL_REMOVERULE
|
||||
- type: NETWORK_PORT_DISABLE
|
||||
- type: NETWORK_PORT_ENABLE
|
||||
- type: do_nothing
|
||||
- type: firewall_acl_add_rule
|
||||
- type: firewall_acl_remove_rule
|
||||
- type: network_port_disable
|
||||
- type: network_port_enable
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: FIREWALL_ACL_ADDRULE
|
||||
action: firewall_acl_add_rule
|
||||
options:
|
||||
type: firewall_acl_add_rule
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: internal
|
||||
firewall_port_direction: inbound
|
||||
position: 1
|
||||
permission: 1
|
||||
source_ip_id: 2 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: PERMIT
|
||||
src_ip: 192.168.0.10
|
||||
dst_ip: 0.0.0.0
|
||||
src_port: 80
|
||||
dst_port: HTTP
|
||||
protocol_name: TCP
|
||||
src_wildcard: 0
|
||||
dst_wildcard: 0
|
||||
2:
|
||||
action: FIREWALL_ACL_REMOVERULE
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: internal
|
||||
firewall_port_direction: inbound
|
||||
position: 1
|
||||
3:
|
||||
action: FIREWALL_ACL_ADDRULE
|
||||
action: firewall_acl_add_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: internal
|
||||
firewall_port_direction: outbound
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 2 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 2
|
||||
dest_port_id: 3
|
||||
protocol_id: 2
|
||||
permission: DENY
|
||||
src_ip: 192.168.0.10 # client 1
|
||||
dest_ip: ALL
|
||||
src_port: ARP
|
||||
dst_port: DNS
|
||||
protocol_name: ICMP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
4:
|
||||
action: FIREWALL_ACL_REMOVERULE
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: internal
|
||||
firewall_port_direction: outbound
|
||||
position: 1
|
||||
5:
|
||||
action: FIREWALL_ACL_ADDRULE
|
||||
action: firewall_acl_add_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: dmz
|
||||
firewall_port_direction: inbound
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 3 # dmz_server
|
||||
dest_ip_id: 2 # client_1
|
||||
source_port_id: 4
|
||||
dest_port_id: 4
|
||||
protocol_id: 4
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.10 # dmz_server
|
||||
dest_ip: 192.168.0.10 # client_1
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
protocol_name: UDP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
6:
|
||||
action: FIREWALL_ACL_REMOVERULE
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: dmz
|
||||
firewall_port_direction: inbound
|
||||
position: 1
|
||||
7:
|
||||
action: FIREWALL_ACL_ADDRULE
|
||||
action: firewall_acl_add_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: dmz
|
||||
firewall_port_direction: outbound
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 3 # dmz_server
|
||||
dest_ip_id: 2 # client_1
|
||||
source_port_id: 4
|
||||
dest_port_id: 4
|
||||
protocol_id: 3
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.10 # dmz_server
|
||||
dest_ip: 192.168.0.10 # client_1
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
protocol_name: TCP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
8:
|
||||
action: FIREWALL_ACL_REMOVERULE
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: dmz
|
||||
firewall_port_direction: outbound
|
||||
position: 2
|
||||
9:
|
||||
action: FIREWALL_ACL_ADDRULE
|
||||
action: firewall_acl_add_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: external
|
||||
firewall_port_direction: inbound
|
||||
position: 10
|
||||
permission: 2
|
||||
source_ip_id: 4 # external_computer
|
||||
dest_ip_id: 3 # dmz
|
||||
source_port_id: 5
|
||||
dest_port_id: 5
|
||||
protocol_id: 2
|
||||
permission: DENY
|
||||
src_ip: 192.168.20.10 # external_computer
|
||||
dest_ip: 192.168.10.10 # dmz
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
protocol_name: ICMP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
10:
|
||||
action: FIREWALL_ACL_REMOVERULE
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: external
|
||||
firewall_port_direction: inbound
|
||||
position: 10
|
||||
11:
|
||||
action: FIREWALL_ACL_ADDRULE
|
||||
action: firewall_acl_add_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: external
|
||||
firewall_port_direction: outbound
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 4 # external_computer
|
||||
dest_ip_id: 2 # client_1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
permission: DENY
|
||||
src_ip: 192.168.20.10 # external_computer
|
||||
dest_ip: 192.168.0.10 # client_1
|
||||
src_port: NONE
|
||||
dst_port: NONE
|
||||
protocol_name: none
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
12:
|
||||
action: FIREWALL_ACL_REMOVERULE
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
target_firewall_nodename: firewall
|
||||
firewall_port_name: external
|
||||
firewall_port_direction: outbound
|
||||
position: 1
|
||||
13:
|
||||
action: NETWORK_PORT_DISABLE
|
||||
action: network_port_disable
|
||||
options:
|
||||
type: network_port_disable
|
||||
target_nodename: firewall
|
||||
port_id: 3
|
||||
14:
|
||||
action: NETWORK_PORT_ENABLE
|
||||
action: network_port_enable
|
||||
options:
|
||||
type: network_port_enable
|
||||
target_nodename: firewall
|
||||
port_id: 3
|
||||
options:
|
||||
|
||||
@@ -201,8 +201,6 @@ simulation:
|
||||
port_scan_p_of_success: 0.8
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
@@ -233,8 +231,6 @@ simulation:
|
||||
server_password: arcd
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: switch_1
|
||||
|
||||
@@ -34,15 +34,16 @@ agents:
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_NETWORK_SERVICE_RECON
|
||||
- type: node_network_service_recon
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_NETWORK_SERVICE_RECON
|
||||
action: node_network_service_recon
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.10.0/24
|
||||
target_port: 80
|
||||
target_protocol: tcp
|
||||
show: false
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -34,13 +34,14 @@ agents:
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_PING_SCAN
|
||||
- type: node_nmap_ping_scan
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_PING_SCAN
|
||||
action: node_nmap_ping_scan
|
||||
options:
|
||||
source_node: client_1
|
||||
node_name: client_1
|
||||
target_ip_address: 192.168.1.0/24
|
||||
show: False
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -34,19 +34,21 @@ agents:
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
action_list:
|
||||
- type: NODE_NMAP_PORT_SCAN
|
||||
- type: node_nmap_port_scan
|
||||
action_map:
|
||||
0:
|
||||
action: NODE_NMAP_PORT_SCAN
|
||||
action: node_nmap_port_scan
|
||||
options:
|
||||
source_node: client_1
|
||||
target_ip_address: 192.168.10.0/24
|
||||
target_protocol: tcp
|
||||
target_port:
|
||||
- 21
|
||||
- 53
|
||||
- 80
|
||||
- 123
|
||||
- 219
|
||||
show: false
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -210,7 +210,6 @@ simulation:
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
fix_duration: 3
|
||||
- type: DNSServer
|
||||
options:
|
||||
@@ -251,8 +250,6 @@ simulation:
|
||||
server_password: arcd
|
||||
services:
|
||||
- type: DNSClient
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: switch_1
|
||||
|
||||
@@ -415,85 +415,58 @@ def game_and_agent():
|
||||
install_stuff_to_sim(sim)
|
||||
|
||||
actions = [
|
||||
{"type": "DONOTHING"},
|
||||
{"type": "NODE_SERVICE_SCAN"},
|
||||
{"type": "NODE_SERVICE_STOP"},
|
||||
{"type": "NODE_SERVICE_START"},
|
||||
{"type": "NODE_SERVICE_PAUSE"},
|
||||
{"type": "NODE_SERVICE_RESUME"},
|
||||
{"type": "NODE_SERVICE_RESTART"},
|
||||
{"type": "NODE_SERVICE_DISABLE"},
|
||||
{"type": "NODE_SERVICE_ENABLE"},
|
||||
{"type": "NODE_SERVICE_FIX"},
|
||||
{"type": "NODE_APPLICATION_EXECUTE"},
|
||||
{"type": "NODE_APPLICATION_SCAN"},
|
||||
{"type": "NODE_APPLICATION_CLOSE"},
|
||||
{"type": "NODE_APPLICATION_FIX"},
|
||||
{"type": "NODE_APPLICATION_INSTALL"},
|
||||
{"type": "NODE_APPLICATION_REMOVE"},
|
||||
{"type": "NODE_FILE_CREATE"},
|
||||
{"type": "NODE_FILE_SCAN"},
|
||||
{"type": "NODE_FILE_CHECKHASH"},
|
||||
{"type": "NODE_FILE_DELETE"},
|
||||
{"type": "NODE_FILE_REPAIR"},
|
||||
{"type": "NODE_FILE_RESTORE"},
|
||||
{"type": "NODE_FILE_CORRUPT"},
|
||||
{"type": "NODE_FILE_ACCESS"},
|
||||
{"type": "NODE_FOLDER_CREATE"},
|
||||
{"type": "NODE_FOLDER_SCAN"},
|
||||
{"type": "NODE_FOLDER_CHECKHASH"},
|
||||
{"type": "NODE_FOLDER_REPAIR"},
|
||||
{"type": "NODE_FOLDER_RESTORE"},
|
||||
{"type": "NODE_OS_SCAN"},
|
||||
{"type": "NODE_SHUTDOWN"},
|
||||
{"type": "NODE_STARTUP"},
|
||||
{"type": "NODE_RESET"},
|
||||
{"type": "ROUTER_ACL_ADDRULE"},
|
||||
{"type": "ROUTER_ACL_REMOVERULE"},
|
||||
{"type": "HOST_NIC_ENABLE"},
|
||||
{"type": "HOST_NIC_DISABLE"},
|
||||
{"type": "NETWORK_PORT_ENABLE"},
|
||||
{"type": "NETWORK_PORT_DISABLE"},
|
||||
{"type": "CONFIGURE_C2_BEACON"},
|
||||
{"type": "C2_SERVER_RANSOMWARE_LAUNCH"},
|
||||
{"type": "C2_SERVER_RANSOMWARE_CONFIGURE"},
|
||||
{"type": "C2_SERVER_TERMINAL_COMMAND"},
|
||||
{"type": "C2_SERVER_DATA_EXFILTRATE"},
|
||||
{"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"},
|
||||
{"type": "SSH_TO_REMOTE"},
|
||||
{"type": "SESSIONS_REMOTE_LOGOFF"},
|
||||
{"type": "NODE_SEND_REMOTE_COMMAND"},
|
||||
{"type": "do_nothing"},
|
||||
{"type": "node_service_scan"},
|
||||
{"type": "node_service_stop"},
|
||||
{"type": "node_service_start"},
|
||||
{"type": "node_service_pause"},
|
||||
{"type": "node_service_resume"},
|
||||
{"type": "node_service_restart"},
|
||||
{"type": "node_service_disable"},
|
||||
{"type": "node_service_enable"},
|
||||
{"type": "node_service_fix"},
|
||||
{"type": "node_application_execute"},
|
||||
{"type": "node_application_scan"},
|
||||
{"type": "node_application_close"},
|
||||
{"type": "node_application_fix"},
|
||||
{"type": "node_application_install"},
|
||||
{"type": "node_application_remove"},
|
||||
{"type": "node_file_create"},
|
||||
{"type": "node_file_scan"},
|
||||
{"type": "node_file_checkhash"},
|
||||
{"type": "node_file_delete"},
|
||||
{"type": "node_file_repair"},
|
||||
{"type": "node_file_restore"},
|
||||
{"type": "node_file_corrupt"},
|
||||
{"type": "node_file_access"},
|
||||
{"type": "node_folder_create"},
|
||||
{"type": "node_folder_scan"},
|
||||
{"type": "node_folder_checkhash"},
|
||||
{"type": "node_folder_repair"},
|
||||
{"type": "node_folder_restore"},
|
||||
{"type": "node_os_scan"},
|
||||
{"type": "node_shutdown"},
|
||||
{"type": "node_startup"},
|
||||
{"type": "node_reset"},
|
||||
{"type": "router_acl_add_rule"},
|
||||
{"type": "router_acl_remove_rule"},
|
||||
{"type": "host_nic_enable"},
|
||||
{"type": "host_nic_disable"},
|
||||
{"type": "network_port_enable"},
|
||||
{"type": "network_port_disable"},
|
||||
{"type": "configure_c2_beacon"},
|
||||
{"type": "c2_server_ransomware_launch"},
|
||||
{"type": "c2_server_ransomware_configure"},
|
||||
{"type": "c2_server_terminal_command"},
|
||||
{"type": "c2_server_data_exfiltrate"},
|
||||
{"type": "node_account_change_password"},
|
||||
{"type": "node_session_remote_login"},
|
||||
{"type": "node_session_remote_logoff"},
|
||||
{"type": "node_send_remote_command"},
|
||||
]
|
||||
|
||||
action_space = ActionManager(
|
||||
actions=actions, # ALL POSSIBLE ACTIONS
|
||||
nodes=[
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"applications": [
|
||||
{"application_name": "WebBrowser"},
|
||||
{"application_name": "DoSBot"},
|
||||
{"application_name": "C2Server"},
|
||||
],
|
||||
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
|
||||
},
|
||||
{
|
||||
"node_name": "server_1",
|
||||
"services": [{"service_name": "DNSServer"}],
|
||||
"applications": [{"application_name": "C2Beacon"}],
|
||||
},
|
||||
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
|
||||
{"node_name": "router"},
|
||||
],
|
||||
max_folders_per_node=2,
|
||||
max_files_per_folder=2,
|
||||
max_services_per_node=2,
|
||||
max_applications_per_node=3,
|
||||
max_nics_per_node=2,
|
||||
max_acl_rules=10,
|
||||
protocols=["TCP", "UDP", "ICMP"],
|
||||
ports=["HTTP", "DNS", "ARP"],
|
||||
ip_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
|
||||
act_map={},
|
||||
)
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
|
||||
@@ -5,6 +5,7 @@ from primaite.config.load import get_extended_config_path
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config
|
||||
from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication
|
||||
from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
|
||||
@@ -13,11 +14,12 @@ from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
|
||||
from tests.integration_tests.extensions.nodes.super_computer import SuperComputer
|
||||
from tests.integration_tests.extensions.services.extended_service import ExtendedService
|
||||
|
||||
CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml"
|
||||
|
||||
|
||||
def test_extended_example_config():
|
||||
"""Test that the example config can be parsed properly."""
|
||||
config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml")
|
||||
game = load_config(config_path)
|
||||
game = load_config(CONFIG_PATH)
|
||||
network: Network = game.simulation.network
|
||||
|
||||
assert len(network.nodes) == 10 # 10 nodes in example network
|
||||
|
||||
@@ -134,7 +134,7 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA
|
||||
|
||||
# Stepping a few timesteps to allow for the RansowmareScript to finish installing.
|
||||
|
||||
action = ("DONOTHING", {})
|
||||
action = ("do_nothing", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
game.step()
|
||||
|
||||
@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
from primaite.game.agent.actions.software import (
|
||||
ConfigureDatabaseClientAction,
|
||||
ConfigureDoSBotAction,
|
||||
ConfigureRansomwareScriptAction,
|
||||
@@ -35,10 +35,10 @@ class TestConfigureDatabaseAction:
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
"configure_database_client",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"node_name": "client_1",
|
||||
"model_config": {
|
||||
"server_ip_address": "192.168.1.99",
|
||||
"server_password": "admin123",
|
||||
},
|
||||
@@ -53,7 +53,7 @@ class TestConfigureDatabaseAction:
|
||||
def test_configure_ip(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
agent.action_manager.actions["configure_database_client"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
@@ -36,7 +36,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
|
||||
assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
for i in range(client_1.shut_down_duration + 1):
|
||||
action = ("DONOTHING", {"node_id": 0})
|
||||
action = ("do_nothing", {"node_id": 0})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
|
||||
assert client_1.operating_state == NodeOperatingState.BOOTING
|
||||
|
||||
for i in range(client_1.start_up_duration + 1):
|
||||
action = ("DONOTHING", {"node_id": 0})
|
||||
action = ("do_nothing", {"node_id": 0})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -80,7 +80,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture:
|
||||
client_1.power_off()
|
||||
|
||||
for i in range(client_1.shut_down_duration + 1):
|
||||
action = ("DONOTHING", {"node_id": 0})
|
||||
action = ("do_nothing", {"node_id": 0})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
|
||||
@@ -24,12 +24,12 @@ def test_rng_seed_set(create_env):
|
||||
env.reset(seed=3)
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
|
||||
|
||||
env.reset(seed=3)
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
|
||||
|
||||
assert a == b
|
||||
|
||||
@@ -40,11 +40,11 @@ def test_rng_seed_unset(create_env):
|
||||
env.reset()
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
|
||||
|
||||
env.reset()
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
|
||||
|
||||
assert a != b
|
||||
|
||||
@@ -91,7 +91,7 @@ def test_mask_contents_correct():
|
||||
assert mask[action_num]
|
||||
node_obj.operating_state = NodeOperatingState.ON
|
||||
|
||||
if act_type == "DONOTHING":
|
||||
if act_type == "do_nothing":
|
||||
assert mask[action_num]
|
||||
|
||||
if act_type == "NODE_SERVICE_DISABLE":
|
||||
|
||||
@@ -32,10 +32,10 @@ FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.
|
||||
|
||||
|
||||
def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
"""Test that the DoNothingAction can form a request and that it is accepted by the simulation."""
|
||||
"""Test that the do_nothingAction can form a request and that it is accepted by the simulation."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
action = ("DONOTHING", {})
|
||||
action = ("do_nothing", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
|
||||
assert svc.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
# 2: Scan and check that the visible state is now correct
|
||||
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
|
||||
action = ("node_service_scan", {"type": "node_service_scan", "node_name": "server_1", "service_name": "DNSServer"})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert svc.health_state_actual == SoftwareHealthState.GOOD
|
||||
@@ -67,7 +67,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
|
||||
assert svc.health_state_visible == SoftwareHealthState.GOOD
|
||||
|
||||
# 4: Scan and check that the visible state is now correct
|
||||
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
|
||||
action = ("node_service_scan", {"type": "node_service_scan", "node_name": "server_1", "service_name": "DNSServer"})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert svc.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
@@ -88,7 +88,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
svc.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
|
||||
# 2: Apply a patch action
|
||||
action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0})
|
||||
action = ("node_service_fix", {"type": "node_service_fix", "node_name": "server_1", "service_name": "DNSServer"})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -96,7 +96,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
assert svc.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# 4: perform a few do-nothing steps and check that the service is now in the good state
|
||||
action = ("DONOTHING", {})
|
||||
action = ("do_nothing", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert svc.health_state_actual == SoftwareHealthState.GOOD
|
||||
@@ -121,18 +121,19 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
|
||||
|
||||
# 2: Add a rule to block client 1 from reaching server 2 on router
|
||||
action = (
|
||||
"ROUTER_ACL_ADDRULE",
|
||||
"router_acl_add_rule",
|
||||
{
|
||||
"type": "router_acl_add_rule",
|
||||
"target_router": "router",
|
||||
"position": 4, # 4th rule
|
||||
"permission": 2, # DENY
|
||||
"source_ip_id": 3, # 10.0.1.2 (client_1)
|
||||
"dest_ip_id": 6, # 10.0.2.3 (server_2)
|
||||
"dest_port_id": 1, # ALL
|
||||
"source_port_id": 1, # ALL
|
||||
"protocol_id": 1, # ALL
|
||||
"source_wildcard_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"position": 4,
|
||||
"permission": "DENY",
|
||||
"src_ip": "10.0.1.2",
|
||||
"src_wildcard": 0,
|
||||
"src_port": "HTTP",
|
||||
"dst_ip": "10.0.2.3",
|
||||
"dst_wildcard": 0,
|
||||
"dst_port": "HTTP",
|
||||
"protocol_name": "udp",
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -148,24 +149,27 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
|
||||
|
||||
# 4: Add a rule to block server_1 from reaching server_2 on router (this should not affect comms as they are on same subnet)
|
||||
action = (
|
||||
"ROUTER_ACL_ADDRULE",
|
||||
"router_acl_add_rule",
|
||||
{
|
||||
"type": "router_acl_add_rule",
|
||||
"target_router": "router",
|
||||
"position": 5, # 5th rule
|
||||
"permission": 2, # DENY
|
||||
"source_ip_id": 5, # 10.0.2.2 (server_1)
|
||||
"dest_ip_id": 6, # 10.0.2.3 (server_2)
|
||||
"dest_port_id": 1, # ALL
|
||||
"source_port_id": 1, # ALL
|
||||
"protocol_id": 1, # ALL
|
||||
"source_wildcard_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"permission": "DENY", # DENY
|
||||
"src_ip": "10.0.2.2", # 10.0.2.2 (server_1)
|
||||
"src_wildcard": 0,
|
||||
"source_port": "ALL", # ALL
|
||||
"dst_ip": "10.0.2.3", # 10.0.2.3 (server_2)
|
||||
"dst_wildcard": 0,
|
||||
"dst_port": "ALL", # ALL
|
||||
"protocol_name": "ALL", # ALL
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
print(agent.most_recent_action)
|
||||
game.step()
|
||||
|
||||
print(agent.most_recent_action)
|
||||
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
|
||||
print(router.acl.show())
|
||||
assert router.acl.num_rules == 6
|
||||
assert server_1.ping("10.0.2.3") # Can ping server_2
|
||||
|
||||
@@ -186,8 +190,9 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
|
||||
# 2: Remove rule that allows HTTP traffic across the network
|
||||
action = (
|
||||
"ROUTER_ACL_REMOVERULE",
|
||||
"router_acl_remove_rule",
|
||||
{
|
||||
"type": "router_acl_remove_rule",
|
||||
"target_router": "router",
|
||||
"position": 3, # 4th rule
|
||||
},
|
||||
@@ -219,10 +224,11 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
action = (
|
||||
"HOST_NIC_DISABLE",
|
||||
"host_nic_disable",
|
||||
{
|
||||
"node_id": 0, # client_1
|
||||
"nic_id": 0, # the only nic (eth-1)
|
||||
"type": "host_nic_disable",
|
||||
"node_name": "client_1", # client_1
|
||||
"nic_num": 1, # the only nic (eth-1)
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -250,10 +256,11 @@ def test_host_nic_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAg
|
||||
|
||||
# 2: Use action to enable nic
|
||||
action = (
|
||||
"HOST_NIC_ENABLE",
|
||||
"host_nic_enable",
|
||||
{
|
||||
"node_id": 0, # client_1
|
||||
"nic_id": 0, # the only nic (eth-1)
|
||||
"type": "host_nic_enable",
|
||||
"node_name": "client_1", # client_1
|
||||
"nic_num": 1, # the only nic (eth-1)
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -277,11 +284,12 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge
|
||||
|
||||
# 2: perform a scan and make sure nothing has changed
|
||||
action = (
|
||||
"NODE_FILE_SCAN",
|
||||
"node_file_scan",
|
||||
{
|
||||
"node_id": 0, # client_1,
|
||||
"folder_id": 0, # downloads,
|
||||
"file_id": 0, # cat.png
|
||||
"type": "node_file_scan",
|
||||
"node_name": "client_1", # client_1,
|
||||
"folder_name": "downloads", # downloads,
|
||||
"file_name": "cat.png", # cat.png
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -314,11 +322,12 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
|
||||
# 2: delete the file
|
||||
action = (
|
||||
"NODE_FILE_DELETE",
|
||||
"node_file_delete",
|
||||
{
|
||||
"node_id": 0, # client_1
|
||||
"folder_id": 0, # downloads
|
||||
"file_id": 0, # cat.png
|
||||
"type": "node_file_delete",
|
||||
"node_name": "client_1", # client_1
|
||||
"folder_name": "downloads", # downloads
|
||||
"file_name": "cat.png", # cat.png
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -334,14 +343,16 @@ def test_node_file_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
"""Test that a file is created."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
action = (
|
||||
"NODE_FILE_CREATE",
|
||||
"node_file_create",
|
||||
{
|
||||
"node_id": 0,
|
||||
"type": "node_file_create",
|
||||
"node_name": "client_1",
|
||||
"folder_name": "test",
|
||||
"file_name": "file.txt",
|
||||
"force": "False",
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -357,9 +368,10 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
|
||||
|
||||
action = (
|
||||
"NODE_FILE_CREATE",
|
||||
"node_file_create",
|
||||
{
|
||||
"node_id": 0,
|
||||
"type": "node_file_create",
|
||||
"node_name": "client_1",
|
||||
"folder_name": "test",
|
||||
"file_name": "file.txt",
|
||||
},
|
||||
@@ -370,9 +382,10 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
assert client_1.file_system.get_file(folder_name="test", file_name="file.txt").num_access == 0
|
||||
|
||||
action = (
|
||||
"NODE_FILE_ACCESS",
|
||||
"node_file_access",
|
||||
{
|
||||
"node_id": 0,
|
||||
"type": "node_file_access",
|
||||
"node_name": "client_1",
|
||||
"folder_name": "test",
|
||||
"file_name": "file.txt",
|
||||
},
|
||||
@@ -390,9 +403,10 @@ def test_node_folder_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
|
||||
|
||||
action = (
|
||||
"NODE_FOLDER_CREATE",
|
||||
"node_folder_create",
|
||||
{
|
||||
"node_id": 0,
|
||||
"type": "node_folder_create",
|
||||
"node_name": "client_1",
|
||||
"folder_name": "test",
|
||||
},
|
||||
)
|
||||
@@ -418,8 +432,9 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
action = (
|
||||
"NETWORK_PORT_DISABLE",
|
||||
"network_port_disable",
|
||||
{
|
||||
"type": "network_port_disable",
|
||||
"target_nodename": "router", # router
|
||||
"port_id": 1, # port 1
|
||||
},
|
||||
@@ -450,8 +465,9 @@ def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGa
|
||||
|
||||
# 2: Use action to enable port
|
||||
action = (
|
||||
"NETWORK_PORT_ENABLE",
|
||||
"network_port_enable",
|
||||
{
|
||||
"type": "network_port_enable",
|
||||
"target_nodename": "router", # router
|
||||
"port_id": 1, # port 1
|
||||
},
|
||||
@@ -480,7 +496,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
assert browser.health_state_visible == SoftwareHealthState.UNUSED
|
||||
|
||||
# 2: Scan and check that the visible state is now correct
|
||||
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
|
||||
action = (
|
||||
"node_application_scan",
|
||||
{"type": "node_application_scan", "node_name": "client_1", "application_name": "WebBrowser"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert browser.health_state_actual == SoftwareHealthState.GOOD
|
||||
@@ -491,7 +510,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
assert browser.health_state_visible == SoftwareHealthState.GOOD
|
||||
|
||||
# 4: Scan and check that the visible state is now correct
|
||||
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
|
||||
action = (
|
||||
"node_application_scan",
|
||||
{"type": "node_application_scan", "node_name": "client_1", "application_name": "WebBrowser"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert browser.health_state_actual == SoftwareHealthState.COMPROMISED
|
||||
@@ -512,7 +534,10 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
|
||||
browser.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
|
||||
# 2: Apply a fix action
|
||||
action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0})
|
||||
action = (
|
||||
"node_application_fix",
|
||||
{"type": "node_application_fix", "node_name": "client_1", "application_name": "WebBrowser"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -520,7 +545,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
|
||||
assert browser.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# 4: perform a few do-nothing steps and check that the application is now in the good state
|
||||
action = ("DONOTHING", {})
|
||||
action = ("do_nothing", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert browser.health_state_actual == SoftwareHealthState.GOOD
|
||||
@@ -538,7 +563,10 @@ def test_node_application_close_integration(game_and_agent: Tuple[PrimaiteGame,
|
||||
assert browser.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
# 2: Apply a close action
|
||||
action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0})
|
||||
action = (
|
||||
"node_application_close",
|
||||
{"type": "node_application_close", "node_name": "client_1", "application_name": "WebBrowser"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -549,7 +577,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
|
||||
"""Test that the NodeApplicationInstallAction and NodeApplicationRemoveAction can form a request and that
|
||||
it is accepted by the simulation.
|
||||
|
||||
When you initiate a install action, the Application will be installed and configured on the node.
|
||||
When you initiate an install action, the Application will be installed and configured on the node.
|
||||
The remove action will uninstall the application from the node."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
@@ -557,13 +585,19 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
|
||||
|
||||
assert client_1.software_manager.software.get("DoSBot") is None
|
||||
|
||||
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
|
||||
action = (
|
||||
"node_application_install",
|
||||
{"type": "node_application_install", "node_name": "client_1", "application_name": "DoSBot"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert client_1.software_manager.software.get("DoSBot") is not None
|
||||
|
||||
action = ("NODE_APPLICATION_REMOVE", {"node_id": 0, "application_name": "DoSBot"})
|
||||
action = (
|
||||
"node_application_remove",
|
||||
{"type": "node_application_remove", "node_name": "client_1", "application_name": "DoSBot"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
@@ -656,9 +690,9 @@ def test_firewall_acl_add_remove_rule_integration():
|
||||
assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
|
||||
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
|
||||
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
|
||||
assert firewall.external_outbound_acl.acl[1].dst_port is None
|
||||
assert firewall.external_outbound_acl.acl[1].src_port is None
|
||||
assert firewall.external_outbound_acl.acl[1].protocol is None
|
||||
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"]
|
||||
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"]
|
||||
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
|
||||
|
||||
env.step(12) # Remove ACL rule from External Outbound
|
||||
assert firewall.external_outbound_acl.num_rules == 1
|
||||
|
||||
@@ -18,12 +18,14 @@ from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import ControlledAgent
|
||||
|
||||
|
||||
def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
|
||||
"""Test that we get the right reward for failing to fetch a website."""
|
||||
# set up the scenario, configure the web browser to the correct url
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = WebpageUnavailablePenalty(node_hostname="client_1")
|
||||
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
|
||||
comp = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
@@ -31,7 +33,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
|
||||
# Check that before trying to fetch the webpage, the reward is 0.0
|
||||
agent.store_action(("DONOTHING", {}))
|
||||
agent.store_action(("do_nothing", {}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.0
|
||||
|
||||
@@ -53,7 +55,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
assert agent.reward_function.current_reward == -0.7
|
||||
|
||||
|
||||
def test_uc2_rewards(game_and_agent):
|
||||
def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
|
||||
"""Test that the reward component correctly applies a penalty when the selected client cannot reach the database."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
@@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent):
|
||||
ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2
|
||||
)
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
|
||||
comp = GreenAdminDatabaseUnreachablePenalty(config=schema)
|
||||
|
||||
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
response = game.simulation.apply_request(request)
|
||||
@@ -139,17 +142,19 @@ def test_action_penalty_loads_from_config():
|
||||
act_penalty_obj = comp[0]
|
||||
if act_penalty_obj is None:
|
||||
pytest.fail("Action penalty reward component was not added to the agent from config.")
|
||||
assert act_penalty_obj.action_penalty == -0.75
|
||||
assert act_penalty_obj.do_nothing_penalty == 0.125
|
||||
assert act_penalty_obj.config.action_penalty == -0.75
|
||||
assert act_penalty_obj.config.do_nothing_penalty == 0.125
|
||||
|
||||
|
||||
def test_action_penalty():
|
||||
"""Test that the action penalty is correctly applied when agent performs any action"""
|
||||
|
||||
# Create an ActionPenalty Reward
|
||||
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
# Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
Penalty = ActionPenalty(config=schema)
|
||||
|
||||
# Assert that penalty is applied if action isn't DONOTHING
|
||||
# Assert that penalty is applied if action isn't do_nothing
|
||||
reward_value = Penalty.calculate(
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
@@ -163,12 +168,12 @@ def test_action_penalty():
|
||||
|
||||
assert reward_value == -0.75
|
||||
|
||||
# Assert that no penalty applied for a DONOTHING action
|
||||
# Assert that no penalty applied for a do_nothing action
|
||||
reward_value = Penalty.calculate(
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="DONOTHING",
|
||||
action="do_nothing",
|
||||
parameters={},
|
||||
request=["do_nothing"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
@@ -178,15 +183,16 @@ def test_action_penalty():
|
||||
assert reward_value == 0.125
|
||||
|
||||
|
||||
def test_action_penalty_e2e(game_and_agent):
|
||||
def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
|
||||
"""Test that we get the right reward for doing actions to fetch a website."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
comp = ActionPenalty(config=schema)
|
||||
|
||||
agent.reward_function.register_component(comp, 1.0)
|
||||
|
||||
action = ("DONOTHING", {})
|
||||
action = ("do_nothing", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.125
|
||||
|
||||
@@ -3,9 +3,11 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
from primaite.game.agent.actions import ( # DoNothingAction,; NodeServiceDisableAction,; NodeServiceEnableAction,; NodeServicePauseAction,; NodeServiceRestartAction,; NodeServiceResumeAction,; NodeServiceScanAction,; NodeServiceStartAction,; NodeServiceStopAction,
|
||||
ActionManager,
|
||||
DoNothingAction,
|
||||
)
|
||||
from primaite.game.agent.actions.manager import DoNothingAction
|
||||
from primaite.game.agent.actions.service import (
|
||||
NodeServiceDisableAction,
|
||||
NodeServiceEnableAction,
|
||||
NodeServicePauseAction,
|
||||
@@ -18,7 +20,7 @@ from primaite.game.agent.actions import (
|
||||
|
||||
|
||||
def test_do_nothing_action_form_request():
|
||||
"""Test that the DoNothingAction can form a request and that it is correct."""
|
||||
"""Test that the do_nothingAction can form a request and that it is correct."""
|
||||
manager = Mock()
|
||||
|
||||
action = DoNothingAction(manager=manager)
|
||||
|
||||
@@ -28,9 +28,9 @@ def test_probabilistic_agent():
|
||||
|
||||
action_space_cfg = {
|
||||
"action_list": [
|
||||
{"type": "DONOTHING"},
|
||||
{"type": "NODE_APPLICATION_EXECUTE"},
|
||||
{"type": "NODE_FILE_DELETE"},
|
||||
{"type": "do_nothing"},
|
||||
{"type": "node_application_execute"},
|
||||
{"type": "node_file_delete"},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
@@ -48,9 +48,9 @@ def test_probabilistic_agent():
|
||||
"protocols": ["TCP", "UDP", "ICMP"],
|
||||
"ports": ["HTTP", "DNS", "ARP"],
|
||||
"act_map": {
|
||||
0: {"action": "DONOTHING", "options": {}},
|
||||
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
|
||||
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
|
||||
0: {"action": "do_nothing", "options": {}},
|
||||
1: {"action": "node_application_execute", "options": {"node_id": 0, "application_id": 0}},
|
||||
2: {"action": "node_file_delete", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
|
||||
},
|
||||
"options": {},
|
||||
}
|
||||
@@ -80,11 +80,11 @@ def test_probabilistic_agent():
|
||||
node_file_delete_count = 0
|
||||
for _ in range(N_TRIALS):
|
||||
a = pa.get_action(0)
|
||||
if a == ("DONOTHING", {}):
|
||||
if a == ("do_nothing", {}):
|
||||
do_nothing_count += 1
|
||||
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
|
||||
elif a == ("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}):
|
||||
node_application_execute_count += 1
|
||||
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
|
||||
elif a == ("node_file_delete", {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}):
|
||||
node_file_delete_count += 1
|
||||
else:
|
||||
raise AssertionError("Probabilistic agent produced an unexpected action.")
|
||||
|
||||
@@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse
|
||||
|
||||
class TestWebServer404PenaltySticky:
|
||||
def test_non_sticky(self):
|
||||
reward = WebServer404Penalty("computer", "WebService", sticky=False)
|
||||
schema = WebServer404Penalty.ConfigSchema(
|
||||
node_hostname="computer",
|
||||
service_name="WebService",
|
||||
sticky=False,
|
||||
)
|
||||
reward = WebServer404Penalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
codes = []
|
||||
@@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = WebServer404Penalty("computer", "WebService", sticky=True)
|
||||
schema = WebServer404Penalty.ConfigSchema(
|
||||
node_hostname="computer",
|
||||
service_name="WebService",
|
||||
sticky=True,
|
||||
)
|
||||
reward = WebServer404Penalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
codes = []
|
||||
@@ -67,10 +77,11 @@ class TestWebServer404PenaltySticky:
|
||||
|
||||
class TestWebpageUnavailabilitySticky:
|
||||
def test_non_sticky(self):
|
||||
reward = WebpageUnavailablePenalty("computer", sticky=False)
|
||||
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False)
|
||||
reward = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
@@ -93,7 +104,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
@@ -127,10 +138,11 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = WebpageUnavailablePenalty("computer", sticky=True)
|
||||
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True)
|
||||
reward = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
@@ -153,7 +165,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is sticky, it stays at 1.0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -188,10 +200,14 @@ class TestWebpageUnavailabilitySticky:
|
||||
|
||||
class TestGreenAdminDatabaseUnreachableSticky:
|
||||
def test_non_sticky(self):
|
||||
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)
|
||||
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(
|
||||
node_hostname="computer",
|
||||
sticky=False,
|
||||
)
|
||||
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -212,9 +228,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
timestep=0, action=action, parameters=params, request=request, response=response
|
||||
@@ -244,10 +259,14 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)
|
||||
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(
|
||||
node_hostname="computer",
|
||||
sticky=True,
|
||||
)
|
||||
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -268,7 +287,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
|
||||
Reference in New Issue
Block a user