Merge agents and actions branches + fix import / subclass errors

This commit is contained in:
Marek Wolan
2025-01-14 11:34:01 +00:00
57 changed files with 2297 additions and 2381 deletions

View File

@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty.
- Actions within PrimAITE are now extensible, allowing for plugin support.
## [3.3.0] - 2024-09-04

View File

@@ -0,0 +1,67 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
Extensible Actions
******************
Changes to Actions class Structure.
===================================
Actions within PrimAITE have been updated to inherit from a base class, AbstractAction, standardising their format and allowing for easier creation of custom actions. Actions now use a ``ConfigSchema`` to define the possible configuration variables, and use pydantic to enforce correct parameters are passed through.
Developing Custom Actions.
==========================
Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and contain 3 key items:
#. ConfigSchema class
#. Unique Identifier
#. `from_request` method.
ConfigSchema
############
The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file.
Unique Identifier
#################
When declaring a custom class, it must have a unique identifier string, that allows PrimAITE to generate the correct action when needed.
.. code:: Python
class CreateDirectoryAction(AbstractAction, identifier="node_folder_create")
config: CreateDirectoryAction.ConfigSchema
class ConfigSchema(AbstractAction.ConfigSchema):
verb: ClassVar[str] = "create"
node_name: str
directory_name: str
def form_request(cls, config: ConfigSchema) -> RequestFormat:
return ["network",
"node",
config.node_name,
"file_system",
config.verb,
"folder",
config.directory_name,
]
The above action would fail pydantic validation as the identifier "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`.
from_request method
###################
PrimAITE actions need to be have a `from_request` method, which can be passed to the `RequestManager` for processing. This allows the custom action to be actioned within the simulation environment.

View File

@@ -0,0 +1,57 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _about:
Extensible Rewards
******************
Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x as new reward
types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE
core repository).
Changes to reward class structure.
==================================
Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel).
Within the reward class there is a ConfigSchema class responsible for ensuring the config file data
is in the correct format. This also means there is little (if no) requirement for and `__init__`
method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`.
Each class requires an identifier string which is used by the ConfigSchema class to verify that it
hasn't previously been added to the registry.
Inheriting from `BaseModel` removes the need for an `__init__` method but means that object
attributes need to be passed by keyword.
To add a new reward class follow the example below. Note that the type attribute in the
`ConfigSchema` class should match the type used in the config file to define the reward.
.. code-block:: Python
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
config: "DatabaseFileIntegrity.ConfigSchema"
location_in_state: List[str] = [""]
reward: float = 0.0
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for DatabaseFileIntegrity."""
type: str = "DATABASE_FILE_INTEGRITY"
node_hostname: str
folder_name: str
file_name: str
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
pass
Changes to YAML file.
=====================
.. code:: YAML
There's no longer a need to provide a `dns_server` as an option in the simulation section
of the config file.

View File

@@ -55,50 +55,50 @@ agents:
action_space:
action_list:
- type: DONOTHING
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
- type: do_nothing
- type: node_shutdown
- type: node_startup
- type: host_nic_enable
- type: host_nic_enable
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_SHUTDOWN
action: node_shutdown
options:
node_id: 0
node_name: client_1
2:
action: NODE_SHUTDOWN
action: node_shutdown
options:
node_id: 1
node_name: server
3:
action: NODE_STARTUP
action: node_startup
options:
node_id: 0
node_name: client_1
4:
action: NODE_STARTUP
action: node_startup
options:
node_id: 1
node_name: server
5:
action: HOST_NIC_DISABLE
action: host_nic_disable
options:
node_id: 0
node_name: client_1
nic_id: 0
6:
action: HOST_NIC_DISABLE
action: host_nic_disable
options:
node_id: 1
node_name: server
nic_id: 0
7:
action: HOST_NIC_ENABLE
action: host_nic_enable
options:
node_id: 0
node_name: client_1
nic_id: 0
8:
action: HOST_NIC_ENABLE
action: host_nic_enable
options:
node_id: 1
node_name: server
nic_id: 0
options:
nodes:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,33 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from primaite.game.agent.actions import (
abstract,
acl,
application,
file,
folder,
host_nic,
manager,
network,
node,
service,
session,
software,
)
from primaite.game.agent.actions.manager import ActionManager
__all__ = (
"abstract",
"acl",
"application",
"software",
"file",
"folder",
"host_nic",
"manager",
"network",
"node",
"service",
"session",
"ActionManager",
)

View File

@@ -0,0 +1,36 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import ABC
from typing import Any, ClassVar, Dict, Optional, Type
from pydantic import BaseModel, ConfigDict
from primaite.interface.request import RequestFormat
class AbstractAction(BaseModel, ABC):
"""Base class for actions."""
config: "AbstractAction.ConfigSchema"
class ConfigSchema(BaseModel, ABC):
"""Base configuration schema for Actions."""
model_config = ConfigDict(extra="forbid")
type: str
_registry: ClassVar[Dict[str, Type[AbstractAction]]] = {}
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Cannot create new action under reserved name {identifier}")
cls._registry[identifier] = cls
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
pass

View File

@@ -0,0 +1,188 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import ABC
from typing import List
from pydantic import field_validator
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
from primaite.utils.validation.ip_protocol import protocol_validator
from primaite.utils.validation.ipv4_address import ipv4_validator, IPV4Address
from primaite.utils.validation.port import port_validator
__all__ = (
"RouterACLAddRuleAction",
"RouterACLRemoveRuleAction",
"FirewallACLAddRuleAction",
"FirewallACLRemoveRuleAction",
)
class ACLAddRuleAbstractAction(AbstractAction, ABC):
"""Base abstract class for ACL add rule actions."""
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL add rule abstract actions."""
src_ip: IPV4Address
protocol_name: str
permission: str
position: int
dst_ip: IPV4Address
src_port: int
dst_port: int
src_wildcard: int
dst_wildcard: int
@field_validator(
"src_port",
"dst_port",
mode="before",
)
@classmethod
def valid_port(cls, v: str) -> int:
"""Check that inputs are valid."""
return port_validator(v)
@field_validator(
"src_ip",
"dst_ip",
mode="before",
)
@classmethod
def valid_ip(cls, v: str) -> str:
"""Check that a valid IP has been provided for src and dst."""
return ipv4_validator(v)
@field_validator(
"protocol_name",
mode="before",
)
@classmethod
def is_valid_protocol(cls, v: str) -> bool:
"""Check that we are using a valid protocol."""
return protocol_validator(v)
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
"""Base abstract class for ACL remove rule actions."""
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL remove rule abstract actions."""
position: int
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
config: "RouterACLAddRuleAction.ConfigSchema"
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
"""Configuration Schema for RouterACLAddRuleAction."""
target_router: str
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.target_router,
"acl",
"add_rule",
config.permission,
config.protocol_name,
config.src_ip,
config.src_wildcard,
config.src_port,
config.dst_ip,
config.dst_wildcard,
config.dst_port,
config.position,
]
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
"""Action which removes a rule from a router's ACL."""
config: "RouterACLRemoveRuleAction.ConfigSchema"
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
"""Configuration schema for RouterACLRemoveRuleAction."""
target_router: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
"""Action which adds a rule to a firewall port's ACL."""
config: "FirewallACLAddRuleAction.ConfigSchema"
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLAddRuleAction."""
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.target_firewall_nodename,
config.firewall_port_name,
config.firewall_port_direction,
"acl",
"add_rule",
config.permission,
config.protocol_name,
config.src_ip,
config.src_wildcard,
config.src_port,
config.dst_ip,
config.dst_wildcard,
config.dst_port,
config.position,
]
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
"""Action which removes a rule from a firewall port's ACL."""
config: "FirewallACLRemoveRuleAction.ConfigSchema"
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLRemoveRuleAction."""
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.target_firewall_nodename,
config.firewall_port_name,
config.firewall_port_direction,
"acl",
"remove_rule",
config.position,
]

View File

@@ -0,0 +1,137 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.abstract import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeApplicationExecuteAction",
"NodeApplicationScanAction",
"NodeApplicationCloseAction",
"NodeApplicationFixAction",
"NodeApplicationInstallAction",
"NodeApplicationRemoveAction",
)
class NodeApplicationAbstractAction(AbstractAction, ABC):
"""
Base class for application actions.
Any action which applies to an application and uses node_id and application_id as its only two parameters can
inherit from this base class.
"""
config: "NodeApplicationAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for Node Application actions."""
node_name: str
application_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"application",
config.application_name,
config.verb,
]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
"""Action which executes an application."""
config: "NodeApplicationExecuteAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationExecuteAction."""
verb: str = "execute"
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
"""Action which scans an application."""
config: "NodeApplicationScanAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationScanAction."""
verb: str = "scan"
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
"""Action which closes an application."""
config: "NodeApplicationCloseAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationCloseAction."""
verb: str = "close"
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
"""Action which fixes an application."""
config: "NodeApplicationFixAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationFixAction."""
verb: str = "fix"
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
"""Action which installs an application."""
config: "NodeApplicationInstallAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationInstallAction."""
verb: str = "install"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"software_manager",
"application",
config.verb,
config.application_name,
]
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
"""Action which removes/uninstalls an application."""
config: "NodeApplicationRemoveAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationRemoveAction."""
verb: str = "uninstall"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"software_manager",
"application",
config.verb,
config.application_name,
]

View File

@@ -0,0 +1,189 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeFileCreateAction",
"NodeFileScanAction",
"NodeFileDeleteAction",
"NodeFileRestoreAction",
"NodeFileCorruptAction",
"NodeFileAccessAction",
"NodeFileCheckhashAction",
"NodeFileRepairAction",
)
class NodeFileAbstractAction(AbstractAction, ABC):
"""Abstract base class for file actions.
Any action which applies to a file and uses node_name, folder_name, and file_name as its
only three parameters can inherit from this base class.
"""
config: "NodeFileAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for NodeFileAbstractAction."""
node_name: str
folder_name: str
file_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
"folder",
config.folder_name,
"file",
config.file_name,
config.verb,
]
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
"""Action which creates a new file in a given folder."""
config: "NodeFileCreateAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCreateAction."""
verb: ClassVar[str] = "create"
force: bool = False
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
"file",
config.folder_name,
config.file_name,
config.verb,
]
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
"""Action which scans a file."""
config: "NodeFileScanAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileScanAction."""
verb: ClassVar[str] = "scan"
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
"""Action which deletes a file."""
config: "NodeFileDeleteAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileDeleteAction."""
verb: ClassVar[str] = "delete"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
"file",
config.folder_name,
config.file_name,
]
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
"""Action which restores a file."""
config: "NodeFileRestoreAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileRestoreAction."""
verb: ClassVar[str] = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
"""Action which corrupts a file."""
config: "NodeFileCorruptAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCorruptAction."""
verb: ClassVar[str] = "corrupt"
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
"""Action which increases a file's access count."""
config: "NodeFileAccessAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileAccessAction."""
verb: ClassVar[str] = "access"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
config.folder_name,
config.file_name,
]
class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"):
"""Action which checks the hash of a file."""
config: "NodeFileCheckhashAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCheckhashAction."""
verb: ClassVar[str] = "checkhash"
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
"""Action which repairs a file."""
config: "NodeFileRepairAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration Schema for NodeFileRepairAction."""
verb: ClassVar[str] = "repair"

View File

@@ -0,0 +1,117 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeFolderScanAction",
"NodeFolderCheckhashAction",
"NodeFolderRepairAction",
"NodeFolderRestoreAction",
"NodeFolderCreateAction",
)
class NodeFolderAbstractAction(AbstractAction, ABC):
"""
Base class for folder actions.
Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from
this base class.
"""
config: "NodeFolderAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NodeFolder actions."""
node_name: str
folder_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
"folder",
config.folder_name,
config.verb,
]
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
"""Action which scans a folder."""
config: "NodeFolderScanAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderScanAction."""
verb: ClassVar[str] = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
"""Action which checks the hash of a folder."""
config: "NodeFolderCheckhashAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderCheckhashAction."""
verb: ClassVar[str] = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
"""Action which repairs a folder."""
config: "NodeFolderRepairAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderRepairAction."""
verb: ClassVar[str] = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
"""Action which restores a folder."""
config: "NodeFolderRestoreAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderRestoreAction."""
verb: ClassVar[str] = "restore"
class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"):
"""Action which creates a new folder."""
config: "NodeFolderCreateAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderCreateAction."""
verb: ClassVar[str] = "create"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
"folder",
config.folder_name,
]

View File

@@ -0,0 +1,62 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("HostNICEnableAction", "HostNICDisableAction")
class HostNICAbstractAction(AbstractAction, ABC):
"""
Abstract base class for NIC actions.
Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base
class.
"""
config: "HostNICAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for HostNIC actions."""
node_name: str
nic_num: int
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.nic_num is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"network_interface",
config.nic_num,
config.verb,
]
class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
"""Action which enables a NIC."""
config: "HostNICEnableAction.ConfigSchema"
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
"""Configuration schema for HostNICEnableAction."""
verb: ClassVar[str] = "enable"
class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"):
"""Action which disables a NIC."""
config: "HostNICDisableAction.ConfigSchema"
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
"""Configuration schema for HostNICDisableAction."""
verb: ClassVar[str] = "disable"

View File

@@ -0,0 +1,138 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""yaml example.
agents:
- name: agent_1
action_space:
actions:
- do_nothing
- node_service_start
- node_service_stop
action_map:
"""
from __future__ import annotations
from typing import Dict, List, Optional, Tuple
from gymnasium import spaces
# from primaite.game.game import PrimaiteGame # TODO: Breaks things
from primaite.game.agent.actions.abstract import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("DoNothingAction", "ActionManager")
class DoNothingAction(AbstractAction, identifier="do_nothing"):
"""Do Nothing Action."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for do_nothingAction."""
type: str = "do_nothing"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
class ActionManager:
"""Class which manages the action space for an agent."""
def __init__(
self,
actions: List[Dict], # stores list of actions available to agent
nodes: List[Dict], # extra configuration for each node
act_map: Optional[
Dict[int, Dict]
] = None, # allows restricting set of possible actions - TODO: Refactor to be a list?
*args,
**kwargs,
) -> None:
"""Init method for ActionManager.
:param game: Reference to the game to which the agent belongs.
:type game: PrimaiteGame
:param actions: List of action specs which should be made available to the agent. The keys of each spec are:
'type' and 'options' for passing any options to the action class's init method
:type actions: List[dict]
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.actions: Dict[str, AbstractAction] = {}
for act_spec in actions:
act_type = act_spec.get("type")
self.actions[act_type] = AbstractAction._registry[act_type]
self.action_map: Dict[int, Tuple[str, Dict]] = {}
"""
Action mapping that converts an integer to a specific action and parameter choice.
For example :
{0: ("node_service_scan", {node_name:"client_1", service_name:"WebBrowser"})}
"""
if act_map is None:
# raise RuntimeError("Action map must be specified in the config file.")
pass
else:
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
# make sure all numbers between 0 and N are represented as dict keys in action map
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
def get_action(self, action: int) -> Tuple[str, Dict]:
"""Produce action in CAOS format."""
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
"""The CAOS format is basically a action identifier, followed by parameters stored in a dictionary"""
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_class = AbstractAction._registry[action_identifier]
config = act_class.ConfigSchema(**action_options)
return act_class.form_request(config=config)
@property
def space(self) -> spaces.Space:
"""Return the gymnasium action space for this agent."""
return spaces.Discrete(len(self.action_map))
@classmethod
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": # noqa: F821
"""
Construct an ActionManager from a config definition.
The action space config supports the following three sections:
1. ``action_list``
``action_list`` contains a list action components which need to be included in the action space.
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
which will be passed to the action class's __init__ method during initialisation.
2. ``action_map``
Since the agent uses a discrete action space which acts as a flattened version of the component-based
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
correspond to "node_service_scan" with ``node_name="server"`` and
``service_name="WebBrowser"``, action 2 can be "
3. ``options``
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
These options are used to calculate the shape of the action space, and to provide additional information
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
:param game: The Primaite Game to which the agent belongs.
:type game: PrimaiteGame
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
obj = cls(
actions=cfg["action_list"],
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
act_map=cfg.get("action_map"),
)
return obj

View File

@@ -0,0 +1,57 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"):
"""Base class for Network port actions."""
config: "NetworkPortAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NetworkPort actions."""
target_nodename: str
port_id: int
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.target_nodename is None or config.port_id is None:
return ["do_nothing"]
return [
"network",
"node",
config.target_nodename,
"network_interface",
config.port_id,
config.verb,
]
class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"):
"""Action which enables are port on a router or a firewall."""
config: "NetworkPortEnableAction.ConfigSchema"
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortEnableAction."""
verb: ClassVar[str] = "enable"
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
"""Action which disables are port on a router or a firewall."""
config: "NetworkPortDisableAction.ConfigSchema"
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortDisableAction."""
verb: ClassVar[str] = "disable"

View File

@@ -0,0 +1,195 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, List, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeOSScanAction",
"NodeShutdownAction",
"NodeStartupAction",
"NodeResetAction",
"NodeNMAPPingScanAction",
"NodeNMAPPortScanAction",
"NodeNetworkServiceReconAction",
)
class NodeAbstractAction(AbstractAction, identifier="node_abstract"):
"""
Abstract base class for node actions.
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
"""
config: "NodeAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for Node actions."""
node_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
print(config)
return ["network", "node", config.node_name, config.verb]
class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"):
"""Action which scans a node's OS."""
config: "NodeOSScanAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeOSScanAction."""
verb: ClassVar[str] = "scan"
class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"):
"""Action which shuts down a node."""
config: "NodeShutdownAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeShutdownAction."""
verb: ClassVar[str] = "shutdown"
class NodeStartupAction(NodeAbstractAction, identifier="node_startup"):
"""Action which starts up a node."""
config: "NodeStartupAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeStartupAction."""
verb: ClassVar[str] = "startup"
class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
"""Action which resets a node."""
config: "NodeResetAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeResetAction."""
verb: ClassVar[str] = "reset"
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
"""Base class for NodeNMAP actions."""
config: "NodeNMAPAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration Schema for NodeNMAP actions."""
target_ip_address: Union[str, List[str]]
show: bool = False
node_name: str
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
# NMAP action requests don't share a common format for their requests
# This is just a placeholder to ensure the method is defined.
pass
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
"""Action which performs an NMAP ping scan."""
config: "NodeNMAPPingScanAction.ConfigSchema"
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration schema for NodeNMAPPingScanAction."""
pass
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"application",
"NMAP",
"ping_scan",
{"target_ip_address": config.target_ip_address, "show": config.show},
]
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
"""Action which performs an NMAP port scan."""
config: "NodeNMAPPortScanAction.ConfigSchema"
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration Schema for NodeNMAPPortScanAction."""
source_node: str
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
show: Optional[bool] = (False,)
@classmethod
def form_request(
cls,
config: ConfigSchema,
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"port_scan",
{
"target_ip_address": config.target_ip_address,
"target_port": config.target_port,
"target_protocol": config.target_protocol,
"show": config.show,
},
]
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
config: "NodeNetworkServiceReconAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NodeNetworkServiceReconAction."""
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
show: Optional[bool] = (False,)
@classmethod
def form_request(
cls,
config: ConfigSchema,
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"network_service_recon",
{
"target_ip_address": config.target_ip_address,
"target_port": config.target_port,
"target_protocol": config.target_protocol,
"show": config.show,
},
]

View File

@@ -0,0 +1,135 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeServiceScanAction",
"NodeServiceStopAction",
"NodeServiceStartAction",
"NodeServicePauseAction",
"NodeServiceResumeAction",
"NodeServiceRestartAction",
"NodeServiceDisableAction",
"NodeServiceEnableAction",
"NodeServiceFixAction",
)
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
"""Abstract Action for Node Service related actions.
Any actions which use node_name and service_name can inherit from this class.
"""
config: "NodeServiceAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
service_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.node_name, "service", config.service_name, config.verb]
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
"""Action which scans a service."""
config: "NodeServiceScanAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceScanAction."""
verb: ClassVar[str] = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
"""Action which stops a service."""
config: "NodeServiceStopAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStopAction."""
verb: ClassVar[str] = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
"""Action which starts a service."""
config: "NodeServiceStartAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStartAction."""
verb: ClassVar[str] = "start"
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
"""Action which pauses a service."""
config: "NodeServicePauseAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServicePauseAction."""
verb: ClassVar[str] = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
"""Action which resumes a service."""
config: "NodeServiceResumeAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceResumeAction."""
verb: ClassVar[str] = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
"""Action which restarts a service."""
config: "NodeServiceRestartAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceRestartAction."""
verb: ClassVar[str] = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
"""Action which disables a service."""
config: "NodeServiceDisableAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceDisableAction."""
verb: ClassVar[str] = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
"""Action which enables a service."""
config: "NodeServiceEnableAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceEnableAction."""
verb: ClassVar[str] = "enable"
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
"""Action which fixes a service."""
config: "NodeServiceFixAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceFixAction."""
verb: ClassVar[str] = "fix"

View File

@@ -0,0 +1,108 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeSessionsRemoteLoginAction",
"NodeSessionsRemoteLogoutAction",
"NodeAccountChangePasswordAction",
)
class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"):
"""Base class for NodeSession actions."""
config: "NodeSessionAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NodeSessionAbstractActions."""
node_name: str
remote_ip: str
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""
Abstract method for request forming.
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
"""
pass
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
"""Action which performs a remote session login."""
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLoginAction."""
username: str
password: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"service",
"Terminal",
"ssh_to_remote",
config.username,
config.password,
config.remote_ip,
]
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"):
"""Action which performs a remote session logout."""
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
verb: str = "remote_logoff"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):
"""Action which changes the password for a user."""
config: "NodeAccountChangePasswordAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeAccountsChangePasswordAction."""
username: str
current_password: str
new_password: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"service",
"UserManager",
"change_password",
config.username,
config.current_password,
config.new_password,
]

View File

@@ -0,0 +1,238 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import List, Optional, Union
from pydantic import ConfigDict, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
from primaite.interface.request import RequestFormat
__all__ = (
"ConfigureRansomwareScriptAction",
"ConfigureDoSBotAction",
"ConfigureC2BeaconAction",
"NodeSendRemoteCommandAction",
"TerminalC2ServerAction",
"RansomwareLaunchC2ServerAction",
"ExfiltrationC2ServerAction",
"ConfigureDatabaseClientAction",
)
class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_ransomware_configure"):
"""Action which sets config parameters for a ransomware script on a node."""
config: "ConfigureRansomwareScriptAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for ConfigureRansomwareScriptAction."""
node_name: str
server_ip_address: Optional[str]
server_password: Optional[str]
payload: Optional[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"application",
"RansomwareScript",
"configure",
config.model_config,
]
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
"""Action which sets config parameters for a DoS bot on a node."""
config: "ConfigureDoSBotAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
model_config = ConfigDict(extra="forbid")
target_ip_address: Optional[str] = None
target_port: Optional[str] = None
payload: Optional[str] = None
repeat: Optional[bool] = None
port_scan_p_of_success: Optional[float] = None
dos_intensity: Optional[float] = None
max_sessions: Optional[int] = None
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
self.ConfigSchema.model_validate(config) # check that options adhere to schema
return ["network", "node", config.node_name, "application", "DoSBot", "configure", config]
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
"""Action which configures a C2 Beacon based on the parameters given."""
config: "ConfigureC2BeaconAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for ConfigureC2BeaconAction."""
node_name: str
c2_server_ip_address: str
keep_alive_frequency: int = Field(default=5, ge=1)
masquerade_protocol: str = Field(default="TCP")
masquerade_port: str = Field(default="HTTP")
@field_validator(
"c2_server_ip_address",
"keep_alive_frequency",
"masquerade_protocol",
"masquerade_port",
mode="before",
)
@classmethod
def not_none(cls, v: str, info: ValidationInfo) -> int:
"""If None is passed, use the default value instead."""
if v is None:
return cls.model_fields[info.field_name].default
return v
@classmethod
def form_request(self, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config]
class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):
"""Action which sends a terminal command to a remote node via SSH."""
config: "NodeSendRemoteCommandAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NodeSendRemoteCommandAction."""
node_name: str
remote_ip: str
command: RequestFormat
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"service",
"Terminal",
"send_remote_command",
config.remote_ip,
{"command": config.command},
]
class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
config: "TerminalC2ServerAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
commands: Union[List[RequestFormat], RequestFormat]
ip_address: Optional[str]
username: Optional[str]
password: Optional[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
command_model = {
"commands": config.commands,
"ip_address": config.ip_address,
"username": config.username,
"password": config.password,
}
return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model]
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
config: "RansomwareLaunchC2ServerAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for RansomwareLaunchC2ServerAction."""
node_name: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
# This action currently doesn't require any further configuration options.
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"):
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
config: "ExfiltrationC2ServerAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
username: Optional[str]
password: Optional[str]
target_ip_address: str
target_file_name: str
target_folder_name: str
exfiltration_folder_name: Optional[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
command_model = {
"target_file_name": config.target_file_name,
"target_folder_name": config.target_folder_name,
"exfiltration_folder_name": config.exfiltration_folder_name,
"target_ip_address": config.target_ip_address,
"username": config.username,
"password": config.password,
}
return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model]
class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"):
"""Action which sets config parameters for a database client on a node."""
config: "ConfigureDatabaseClientAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
model_config = ConfigDict(extra="forbid")
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", config.model_config]

View File

@@ -73,11 +73,13 @@ class AbstractAgent(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
type: str = "AbstractAgent"
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
super().__init_subclass__(**kwargs)
@property
def flatten_obs(self) -> bool:

View File

@@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs
__all__ = [
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
"ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",]
# fmt: on

View File

@@ -31,7 +31,7 @@ class AbstractObservation(ABC):
"""Initialise an observation. This method must be overwritten."""
self.default_observation: ObsType
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register an observation type.
@@ -40,6 +40,8 @@ class AbstractObservation(ABC):
:raises ValueError: When attempting to create a component with a name that is already in use.
"""
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Duplicate observation component type {identifier}")
cls._registry[identifier] = cls

View File

@@ -27,9 +27,10 @@ the structure:
service_ref: web_server_database_client
```
"""
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from pydantic import BaseModel
from typing_extensions import Never
from primaite import getLogger
@@ -42,25 +43,28 @@ _LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward:
class AbstractReward(BaseModel):
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
config: "AbstractReward.ConfigSchema"
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
class ConfigSchema(BaseModel, ABC):
"""Config schema for AbstractReward."""
type: str
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Duplicate reward {identifier}")
cls._registry[identifier] = cls
@classmethod
@abstractmethod
def from_config(cls, config: dict) -> "AbstractReward":
def from_config(cls, config: Dict) -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
@@ -68,11 +72,28 @@ class AbstractReward:
:return: The reward component.
:rtype: AbstractReward
"""
return cls()
if config["type"] not in cls._registry:
raise ValueError(f"Invalid reward type {config['type']}")
reward_class = cls._registry[config["type"]]
reward_obj = reward_class(config=reward_class.ConfigSchema(**config))
return reward_obj
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
class DummyReward(AbstractReward, identifier="DUMMY"):
"""Dummy reward function component which always returns 0.0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -86,41 +107,21 @@ class DummyReward(AbstractReward):
"""
return 0.0
@classmethod
def from_config(cls, config: dict) -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:return: The reward component.
:rtype: DummyReward
"""
return cls()
class DatabaseFileIntegrity(AbstractReward):
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
config: "DatabaseFileIntegrity.ConfigSchema"
location_in_state: List[str] = [""]
reward: float = 0.0
:param node_hostname: Hostname of the node which contains the database file.
:type node_hostname: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
:type file_name: str
"""
self.location_in_state = [
"network",
"nodes",
node_hostname,
"file_system",
"folders",
folder_name,
"files",
file_name,
]
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for DatabaseFileIntegrity."""
type: str = "DATABASE_FILE_INTEGRITY"
node_hostname: str
folder_name: str
file_name: str
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -132,6 +133,17 @@ class DatabaseFileIntegrity(AbstractReward):
:return: Reward value
:rtype: float
"""
self.location_in_state = [
"network",
"nodes",
self.config.node_hostname,
"file_system",
"folders",
self.config.folder_name,
"files",
self.config.file_name,
]
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
_LOGGER.debug(
@@ -148,44 +160,21 @@ class DatabaseFileIntegrity(AbstractReward):
else:
return 0
@classmethod
def from_config(cls, config: Dict) -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_hostname = config.get("node_hostname")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not (node_hostname and folder_name and file_name):
msg = f"{cls.__name__} could not be initialised with parameters {config}"
_LOGGER.error(msg)
raise ValueError(msg)
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
"""Initialise the reward component.
config: "WebServer404Penalty.ConfigSchema"
location_in_state: List[str] = [""]
reward: float = 0.0
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebServer404Penalty."""
type: str = "WEB_SERVER_404_PENALTY"
node_hostname: str
service_name: str
sticky: bool = True
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -197,6 +186,13 @@ class WebServer404Penalty(AbstractReward):
:return: Reward value
:rtype: float
"""
self.location_in_state = [
"network",
"nodes",
self.config.node_hostname,
"services",
self.config.service_name,
]
web_service_state = access_from_nested_dict(state, self.location_in_state)
# if webserver is no longer installed on the node, return 0
@@ -211,54 +207,27 @@ class WebServer404Penalty(AbstractReward):
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0
self.reward = 0.0
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
pass
return self.reward
@classmethod
def from_config(cls, config: Dict) -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_hostname = config.get("node_hostname")
service_name = config.get("service_name")
if not (node_hostname and service_name):
msg = (
f"{cls.__name__} could not be initialised from config because node_name and service_ref were not "
"found in reward config."
)
_LOGGER.warning(msg)
raise ValueError(msg)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
class WebpageUnavailablePenalty(AbstractReward):
class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
config: "WebpageUnavailablePenalty.ConfigSchema"
reward: float = 0.0
location_in_state: List[str] = [""] # Calculate in __init__()?
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebpageUnavailablePenalty."""
type: str = "WEBPAGE_UNAVAILABLE_PENALTY"
node_hostname: str = ""
sticky: bool = True
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -274,6 +243,13 @@ class WebpageUnavailablePenalty(AbstractReward):
:return: Reward value
:rtype: float
"""
self.location_in_state = [
"network",
"nodes",
self.config.node_hostname,
"applications",
"WebBrowser",
]
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE:
@@ -283,14 +259,14 @@ class WebpageUnavailablePenalty(AbstractReward):
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
self.config.node_hostname,
"application",
"WebBrowser",
"execute",
]
# skip calculating if sticky and no new codes, reusing last step value
if not request_attempted and self.sticky:
if not request_attempted and self.config.sticky:
return self.reward
if last_action_response.response.status != "success":
@@ -298,7 +274,7 @@ class WebpageUnavailablePenalty(AbstractReward):
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0",
)
self.reward = 0.0
else:
@@ -312,37 +288,19 @@ class WebpageUnavailablePenalty(AbstractReward):
return self.reward
@classmethod
def from_config(cls, config: dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema"
reward: float = 0.0
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"
node_hostname: str
sticky: bool = True
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -362,7 +320,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
self.config.node_hostname,
"application",
"DatabaseClient",
"execute",
@@ -371,7 +329,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
if request_attempted: # if agent makes request, always recalculate fresh value
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
elif not self.sticky: # if no new request and not sticky, set reward to 0
elif not self.config.sticky: # if no new request and not sticky, set reward to 0
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
self.reward = 0.0
else: # if no new request and sticky, reuse reward value from last step
@@ -380,47 +338,30 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return self.reward
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class SharedReward(AbstractReward):
class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
"""Adds another agent's reward to the overall reward."""
def __init__(self, agent_name: Optional[str] = None) -> None:
config: "SharedReward.ConfigSchema"
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for SharedReward."""
type: str = "SHARED_REWARD"
agent_name: str
def default_callback(agent_name: str) -> Never:
"""
Initialise the shared reward.
Default callback to prevent calling this reward until it's properly initialised.
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
correctly.
:param agent_name: The name whose reward is an input
:type agent_name: Optional[str]
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
self.agent_name = agent_name
"""Agent whose reward to track."""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it.
@@ -432,36 +373,25 @@ class SharedReward(AbstractReward):
:return: Reward value
:rtype: float
"""
return self.callback(self.agent_name)
@classmethod
def from_config(cls, config: Dict) -> "SharedReward":
"""
Build the SharedReward object from config.
:param config: Configuration dictionary
:type config: Dict
"""
agent_name = config.get("agent_name")
return cls(agent_name=agent_name)
return self.callback(self.config.agent_name)
class ActionPenalty(AbstractReward):
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
"""Apply a negative reward when taking any action except DONOTHING."""
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.
config: "ActionPenalty.ConfigSchema"
Reward or penalise agents for doing nothing or taking actions.
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for ActionPenalty.
:param action_penalty: Reward to give agents for taking any action except DONOTHING
:param action_penalty: Reward to give agents for taking any action except do_nothing
:type action_penalty: float
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
:param do_nothing_penalty: Reward to give agent for taking the do_nothing action
:type do_nothing_penalty: float
"""
self.action_penalty = action_penalty
self.do_nothing_penalty = do_nothing_penalty
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied.
@@ -473,33 +403,16 @@ class ActionPenalty(AbstractReward):
:return: Reward value
:rtype: float
"""
if last_action_response.action == "DONOTHING":
if last_action_response.action == "do_nothing":
return self.do_nothing_penalty
else:
return self.action_penalty
@classmethod
def from_config(cls, config: Dict) -> "ActionPenalty":
"""Build the ActionPenalty object from config."""
action_penalty = config.get("action_penalty", -1.0)
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
else:
return self.config.action_penalty
class RewardFunction:
"""Manages the reward function for the agent."""
rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
"ACTION_PENALTY": ActionPenalty,
}
"""List of reward class identifiers."""
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
@@ -534,7 +447,7 @@ class RewardFunction:
@classmethod
def from_config(cls, config: Dict) -> "RewardFunction":
"""Create a reward function from a config dictionary.
"""Create a reward function from a config dictionary and its related reward class.
:param config: dict of options for the reward manager's constructor
:type config: Dict
@@ -545,8 +458,11 @@ class RewardFunction:
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
# XXX: If options key is missing add key then add type key.
if "options" not in rew_component_cfg:
rew_component_cfg["options"] = {}
rew_component_cfg["options"]["type"] = rew_type
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}))
rew_instance = AbstractReward.from_config(rew_component_cfg["options"])
new.register_component(component=rew_instance, weight=weight)
return new

View File

@@ -370,7 +370,7 @@ class PrimaiteGame:
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_class)
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
new_service = new_node.software_manager.software[service_class.__name__]
# fixing duration for the service
@@ -580,7 +580,7 @@ class PrimaiteGame:
for comp, weight in agent.reward_function.reward_components:
if isinstance(comp, SharedReward):
comp: SharedReward
graph[name].add(comp.agent_name)
graph[name].add(comp.config.agent_name)
# while constructing the graph, we might as well set up the reward sharing itself.
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward

View File

@@ -19,7 +19,7 @@
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable\n"
"from prettytable import PrettyTable"
]
},
{
@@ -195,7 +195,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -209,7 +209,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -1780,10 +1780,11 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
"from primaite.utils.validation.port import PORT_LOOKUP\n",
"\n",
"# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n",
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n",
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n",
"c2_beacon.establish()\n",
"c2_beacon.show()"
]
@@ -1804,7 +1805,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -1818,7 +1819,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -165,13 +165,13 @@
"\n",
"| node_id | node name |\n",
"|---------|------------------|\n",
"| 1 | domain_controller|\n",
"| 2 | web_server |\n",
"| 3 | database_server |\n",
"| 4 | backup_server |\n",
"| 5 | security_suite |\n",
"| 6 | client_1 |\n",
"| 7 | client_2 |\n",
"| 0 | domain_controller|\n",
"| 1 | web_server |\n",
"| 2 | database_server |\n",
"| 3 | backup_server |\n",
"| 4 | security_suite |\n",
"| 5 | client_1 |\n",
"| 6 | client_2 |\n",
"\n",
"Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n",
"\n",

View File

@@ -95,7 +95,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -109,7 +109,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -168,7 +168,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -182,7 +182,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -166,9 +166,10 @@
"from pathlib import Path\n",
"from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n",
"from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.file_system.file_system import FileSystem\n",
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
"from primaite.utils.validation.port import PORT_LOOKUP\n",
"\n",
"\n",
"# no applications exist yet so we will create our own.\n",
"class MSPaint(Application, identifier=\"MSPaint\"):\n",
@@ -182,7 +183,7 @@
"metadata": {},
"outputs": [],
"source": [
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=PORT_LOOKUP[\"HTTP\"], protocol = PROTOCOL_LOOKUP[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
]
},
{
@@ -249,7 +250,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -263,7 +264,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -532,12 +532,12 @@
},
"outputs": [],
"source": [
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
"\n",
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
" action=ACLAction.DENY,\n",
" protocol=IPProtocol["ICMP"],\n",
" protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n",
" src_ip_address=\"192.168.10.22\",\n",
" position=1\n",
")"
@@ -650,7 +650,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Literal, Type
from typing import Any, ClassVar, Dict, Literal, Optional, Type
from pydantic import BaseModel, model_validator
@@ -49,7 +49,7 @@ class NetworkNodeAdder(BaseModel):
_registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {}
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None:
"""
Register a network node adder class.
@@ -58,6 +58,8 @@ class NetworkNodeAdder(BaseModel):
:raises ValueError: When attempting to register a name that is already reserved.
"""
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Duplicate node adder {identifier}")
cls._registry[identifier] = cls

View File

@@ -1545,7 +1545,7 @@ class Node(SimComponent):
_identifier: ClassVar[str] = "unknown"
"""Identifier for this particular class, used for printing and logging. Each subclass redefines this."""
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register a node type.
@@ -1553,10 +1553,10 @@ class Node(SimComponent):
:type identifier: str
:raises ValueError: When attempting to register an node with a name that is already allocated.
"""
if identifier == "default":
super().__init_subclass__(**kwargs)
if identifier is None:
return
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls

View File

@@ -44,7 +44,7 @@ class Application(IOSoftware):
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register an application type.
@@ -52,9 +52,9 @@ class Application(IOSoftware):
:type identifier: str
:raises ValueError: When attempting to register an application with a name that is already allocated.
"""
if identifier == "default":
return
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls

View File

@@ -308,6 +308,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""
if not self._can_perform_action():
return None
if self.server_ip_address is None:
self.sys_log.warning(f"{self.name}: Database server IP address not provided.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None

View File

@@ -16,7 +16,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.port import Port, PORT_LOOKUP
if TYPE_CHECKING:
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog

View File

@@ -52,7 +52,7 @@ class Service(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register a hostnode type.
@@ -60,11 +60,11 @@ class Service(IOSoftware):
:type identifier: str
:raises ValueError: When attempting to register an hostnode with a name that is already allocated.
"""
if identifier == "default":
super().__init_subclass__(**kwargs)
if identifier is None:
return
# Enforce lowercase registry entries because it makes comparisons everywhere else much easier.
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new hostnode {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls

View File

@@ -31,7 +31,7 @@ def ipv4_validator(v: Any) -> IPv4Address:
IPV4Address: Final[Annotated] = Annotated[IPv4Address, BeforeValidator(ipv4_validator)]
"""
IPv4Address with with IPv4Address with with pre-validation and auto-conversion from str using ipv4_validator..
IPv4Address with pre-validation and auto-conversion from str using ipv4_validator..
This type is essentially an IPv4Address from the standard library's ipaddress module,
but with added validation logic. If you use this custom type, the ipv4_validator function

View File

@@ -205,8 +205,6 @@ simulation:
port_scan_p_of_success: 0.8
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
- type: DNSServer
options:
domain_mapping:

View File

@@ -33,7 +33,7 @@ agents:
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: do_nothing
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
@@ -47,7 +47,7 @@ agents:
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_APPLICATION_EXECUTE
@@ -82,7 +82,7 @@ agents:
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: do_nothing
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
@@ -96,7 +96,7 @@ agents:
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_APPLICATION_EXECUTE
@@ -132,7 +132,7 @@ agents:
action_space:
action_list:
- type: DONOTHING
- type: do_nothing
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
@@ -235,7 +235,7 @@ agents:
action_space:
action_list:
- type: DONOTHING
- type: do_nothing
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
@@ -265,7 +265,7 @@ agents:
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
# scan webapp service
1:

View File

@@ -96,155 +96,158 @@ agents:
action_space:
action_list:
- type: DONOTHING
- type: FIREWALL_ACL_ADDRULE
- type: FIREWALL_ACL_REMOVERULE
- type: NETWORK_PORT_DISABLE
- type: NETWORK_PORT_ENABLE
- type: do_nothing
- type: firewall_acl_add_rule
- type: firewall_acl_remove_rule
- type: network_port_disable
- type: network_port_enable
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: FIREWALL_ACL_ADDRULE
action: firewall_acl_add_rule
options:
type: firewall_acl_add_rule
target_firewall_nodename: firewall
firewall_port_name: internal
firewall_port_direction: inbound
position: 1
permission: 1
source_ip_id: 2 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
permission: PERMIT
src_ip: 192.168.0.10
dst_ip: 0.0.0.0
src_port: 80
dst_port: HTTP
protocol_name: TCP
src_wildcard: 0
dst_wildcard: 0
2:
action: FIREWALL_ACL_REMOVERULE
action: firewall_acl_remove_rule
options:
target_firewall_nodename: firewall
firewall_port_name: internal
firewall_port_direction: inbound
position: 1
3:
action: FIREWALL_ACL_ADDRULE
action: firewall_acl_add_rule
options:
target_firewall_nodename: firewall
firewall_port_name: internal
firewall_port_direction: outbound
position: 1
permission: 2
source_ip_id: 2 # client 1
dest_ip_id: 1 # ALL
source_port_id: 2
dest_port_id: 3
protocol_id: 2
permission: DENY
src_ip: 192.168.0.10 # client 1
dest_ip: ALL
src_port: ARP
dst_port: DNS
protocol_name: ICMP
source_wildcard_id: 0
dest_wildcard_id: 0
4:
action: FIREWALL_ACL_REMOVERULE
action: firewall_acl_remove_rule
options:
target_firewall_nodename: firewall
firewall_port_name: internal
firewall_port_direction: outbound
position: 1
5:
action: FIREWALL_ACL_ADDRULE
action: firewall_acl_add_rule
options:
target_firewall_nodename: firewall
firewall_port_name: dmz
firewall_port_direction: inbound
position: 1
permission: 2
source_ip_id: 3 # dmz_server
dest_ip_id: 2 # client_1
source_port_id: 4
dest_port_id: 4
protocol_id: 4
permission: DENY
src_ip: 192.168.10.10 # dmz_server
dest_ip: 192.168.0.10 # client_1
src_port: HTTP
dst_port: HTTP
protocol_name: UDP
source_wildcard_id: 0
dest_wildcard_id: 0
6:
action: FIREWALL_ACL_REMOVERULE
action: firewall_acl_remove_rule
options:
target_firewall_nodename: firewall
firewall_port_name: dmz
firewall_port_direction: inbound
position: 1
7:
action: FIREWALL_ACL_ADDRULE
action: firewall_acl_add_rule
options:
target_firewall_nodename: firewall
firewall_port_name: dmz
firewall_port_direction: outbound
position: 2
permission: 2
source_ip_id: 3 # dmz_server
dest_ip_id: 2 # client_1
source_port_id: 4
dest_port_id: 4
protocol_id: 3
permission: DENY
src_ip: 192.168.10.10 # dmz_server
dest_ip: 192.168.0.10 # client_1
src_port: HTTP
dst_port: HTTP
protocol_name: TCP
source_wildcard_id: 0
dest_wildcard_id: 0
8:
action: FIREWALL_ACL_REMOVERULE
action: firewall_acl_remove_rule
options:
target_firewall_nodename: firewall
firewall_port_name: dmz
firewall_port_direction: outbound
position: 2
9:
action: FIREWALL_ACL_ADDRULE
action: firewall_acl_add_rule
options:
target_firewall_nodename: firewall
firewall_port_name: external
firewall_port_direction: inbound
position: 10
permission: 2
source_ip_id: 4 # external_computer
dest_ip_id: 3 # dmz
source_port_id: 5
dest_port_id: 5
protocol_id: 2
permission: DENY
src_ip: 192.168.20.10 # external_computer
dest_ip: 192.168.10.10 # dmz
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
protocol_name: ICMP
source_wildcard_id: 0
dest_wildcard_id: 0
10:
action: FIREWALL_ACL_REMOVERULE
action: firewall_acl_remove_rule
options:
target_firewall_nodename: firewall
firewall_port_name: external
firewall_port_direction: inbound
position: 10
11:
action: FIREWALL_ACL_ADDRULE
action: firewall_acl_add_rule
options:
target_firewall_nodename: firewall
firewall_port_name: external
firewall_port_direction: outbound
position: 1
permission: 2
source_ip_id: 4 # external_computer
dest_ip_id: 2 # client_1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
permission: DENY
src_ip: 192.168.20.10 # external_computer
dest_ip: 192.168.0.10 # client_1
src_port: NONE
dst_port: NONE
protocol_name: none
source_wildcard_id: 0
dest_wildcard_id: 0
12:
action: FIREWALL_ACL_REMOVERULE
action: firewall_acl_remove_rule
options:
target_firewall_nodename: firewall
firewall_port_name: external
firewall_port_direction: outbound
position: 1
13:
action: NETWORK_PORT_DISABLE
action: network_port_disable
options:
type: network_port_disable
target_nodename: firewall
port_id: 3
14:
action: NETWORK_PORT_ENABLE
action: network_port_enable
options:
type: network_port_enable
target_nodename: firewall
port_id: 3
options:

View File

@@ -201,8 +201,6 @@ simulation:
port_scan_p_of_success: 0.8
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
- type: DNSServer
options:
domain_mapping:
@@ -233,8 +231,6 @@ simulation:
server_password: arcd
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
links:
- endpoint_a_hostname: switch_1

View File

@@ -34,15 +34,16 @@ agents:
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_NETWORK_SERVICE_RECON
- type: node_network_service_recon
action_map:
0:
action: NODE_NMAP_NETWORK_SERVICE_RECON
action: node_network_service_recon
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
target_port: 80
target_protocol: tcp
show: false
reward_function:
reward_components:

View File

@@ -34,13 +34,14 @@ agents:
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_PING_SCAN
- type: node_nmap_ping_scan
action_map:
0:
action: NODE_NMAP_PING_SCAN
action: node_nmap_ping_scan
options:
source_node: client_1
node_name: client_1
target_ip_address: 192.168.1.0/24
show: False
reward_function:
reward_components:

View File

@@ -34,19 +34,21 @@ agents:
max_services_per_node: 1
max_applications_per_node: 1
action_list:
- type: NODE_NMAP_PORT_SCAN
- type: node_nmap_port_scan
action_map:
0:
action: NODE_NMAP_PORT_SCAN
action: node_nmap_port_scan
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
target_protocol: tcp
target_port:
- 21
- 53
- 80
- 123
- 219
show: false
reward_function:
reward_components:

View File

@@ -210,7 +210,6 @@ simulation:
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
fix_duration: 3
- type: DNSServer
options:
@@ -251,8 +250,6 @@ simulation:
server_password: arcd
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
links:
- endpoint_a_hostname: switch_1

View File

@@ -415,85 +415,58 @@ def game_and_agent():
install_stuff_to_sim(sim)
actions = [
{"type": "DONOTHING"},
{"type": "NODE_SERVICE_SCAN"},
{"type": "NODE_SERVICE_STOP"},
{"type": "NODE_SERVICE_START"},
{"type": "NODE_SERVICE_PAUSE"},
{"type": "NODE_SERVICE_RESUME"},
{"type": "NODE_SERVICE_RESTART"},
{"type": "NODE_SERVICE_DISABLE"},
{"type": "NODE_SERVICE_ENABLE"},
{"type": "NODE_SERVICE_FIX"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_APPLICATION_SCAN"},
{"type": "NODE_APPLICATION_CLOSE"},
{"type": "NODE_APPLICATION_FIX"},
{"type": "NODE_APPLICATION_INSTALL"},
{"type": "NODE_APPLICATION_REMOVE"},
{"type": "NODE_FILE_CREATE"},
{"type": "NODE_FILE_SCAN"},
{"type": "NODE_FILE_CHECKHASH"},
{"type": "NODE_FILE_DELETE"},
{"type": "NODE_FILE_REPAIR"},
{"type": "NODE_FILE_RESTORE"},
{"type": "NODE_FILE_CORRUPT"},
{"type": "NODE_FILE_ACCESS"},
{"type": "NODE_FOLDER_CREATE"},
{"type": "NODE_FOLDER_SCAN"},
{"type": "NODE_FOLDER_CHECKHASH"},
{"type": "NODE_FOLDER_REPAIR"},
{"type": "NODE_FOLDER_RESTORE"},
{"type": "NODE_OS_SCAN"},
{"type": "NODE_SHUTDOWN"},
{"type": "NODE_STARTUP"},
{"type": "NODE_RESET"},
{"type": "ROUTER_ACL_ADDRULE"},
{"type": "ROUTER_ACL_REMOVERULE"},
{"type": "HOST_NIC_ENABLE"},
{"type": "HOST_NIC_DISABLE"},
{"type": "NETWORK_PORT_ENABLE"},
{"type": "NETWORK_PORT_DISABLE"},
{"type": "CONFIGURE_C2_BEACON"},
{"type": "C2_SERVER_RANSOMWARE_LAUNCH"},
{"type": "C2_SERVER_RANSOMWARE_CONFIGURE"},
{"type": "C2_SERVER_TERMINAL_COMMAND"},
{"type": "C2_SERVER_DATA_EXFILTRATE"},
{"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"},
{"type": "SSH_TO_REMOTE"},
{"type": "SESSIONS_REMOTE_LOGOFF"},
{"type": "NODE_SEND_REMOTE_COMMAND"},
{"type": "do_nothing"},
{"type": "node_service_scan"},
{"type": "node_service_stop"},
{"type": "node_service_start"},
{"type": "node_service_pause"},
{"type": "node_service_resume"},
{"type": "node_service_restart"},
{"type": "node_service_disable"},
{"type": "node_service_enable"},
{"type": "node_service_fix"},
{"type": "node_application_execute"},
{"type": "node_application_scan"},
{"type": "node_application_close"},
{"type": "node_application_fix"},
{"type": "node_application_install"},
{"type": "node_application_remove"},
{"type": "node_file_create"},
{"type": "node_file_scan"},
{"type": "node_file_checkhash"},
{"type": "node_file_delete"},
{"type": "node_file_repair"},
{"type": "node_file_restore"},
{"type": "node_file_corrupt"},
{"type": "node_file_access"},
{"type": "node_folder_create"},
{"type": "node_folder_scan"},
{"type": "node_folder_checkhash"},
{"type": "node_folder_repair"},
{"type": "node_folder_restore"},
{"type": "node_os_scan"},
{"type": "node_shutdown"},
{"type": "node_startup"},
{"type": "node_reset"},
{"type": "router_acl_add_rule"},
{"type": "router_acl_remove_rule"},
{"type": "host_nic_enable"},
{"type": "host_nic_disable"},
{"type": "network_port_enable"},
{"type": "network_port_disable"},
{"type": "configure_c2_beacon"},
{"type": "c2_server_ransomware_launch"},
{"type": "c2_server_ransomware_configure"},
{"type": "c2_server_terminal_command"},
{"type": "c2_server_data_exfiltrate"},
{"type": "node_account_change_password"},
{"type": "node_session_remote_login"},
{"type": "node_session_remote_logoff"},
{"type": "node_send_remote_command"},
]
action_space = ActionManager(
actions=actions, # ALL POSSIBLE ACTIONS
nodes=[
{
"node_name": "client_1",
"applications": [
{"application_name": "WebBrowser"},
{"application_name": "DoSBot"},
{"application_name": "C2Server"},
],
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
},
{
"node_name": "server_1",
"services": [{"service_name": "DNSServer"}],
"applications": [{"application_name": "C2Beacon"}],
},
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
{"node_name": "router"},
],
max_folders_per_node=2,
max_files_per_folder=2,
max_services_per_node=2,
max_applications_per_node=3,
max_nics_per_node=2,
max_acl_rules=10,
protocols=["TCP", "UDP", "ICMP"],
ports=["HTTP", "DNS", "ARP"],
ip_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
act_map={},
)
observation_space = ObservationManager(NestedObservation(components={}))

View File

@@ -5,6 +5,7 @@ from primaite.config.load import get_extended_config_path
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from tests import TEST_ASSETS_ROOT
from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config
from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication
from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
@@ -13,11 +14,12 @@ from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
from tests.integration_tests.extensions.nodes.super_computer import SuperComputer
from tests.integration_tests.extensions.services.extended_service import ExtendedService
CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml"
def test_extended_example_config():
"""Test that the example config can be parsed properly."""
config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml")
game = load_config(config_path)
game = load_config(CONFIG_PATH)
network: Network = game.simulation.network
assert len(network.nodes) == 10 # 10 nodes in example network

View File

@@ -134,7 +134,7 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA
# Stepping a few timesteps to allow for the RansowmareScript to finish installing.
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
game.step()

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
import pytest
from pydantic import ValidationError
from primaite.game.agent.actions import (
from primaite.game.agent.actions.software import (
ConfigureDatabaseClientAction,
ConfigureDoSBotAction,
ConfigureRansomwareScriptAction,
@@ -35,10 +35,10 @@ class TestConfigureDatabaseAction:
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
action = (
"CONFIGURE_DATABASE_CLIENT",
"configure_database_client",
{
"node_id": 0,
"config": {
"node_name": "client_1",
"model_config": {
"server_ip_address": "192.168.1.99",
"server_password": "admin123",
},
@@ -53,7 +53,7 @@ class TestConfigureDatabaseAction:
def test_configure_ip(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
agent.action_manager.actions["configure_database_client"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")

View File

@@ -36,7 +36,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
action = ("DONOTHING", {"node_id": 0})
action = ("do_nothing", {"node_id": 0})
agent.store_action(action)
game.step()
@@ -50,7 +50,7 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.BOOTING
for i in range(client_1.start_up_duration + 1):
action = ("DONOTHING", {"node_id": 0})
action = ("do_nothing", {"node_id": 0})
agent.store_action(action)
game.step()
@@ -80,7 +80,7 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture:
client_1.power_off()
for i in range(client_1.shut_down_duration + 1):
action = ("DONOTHING", {"node_id": 0})
action = ("do_nothing", {"node_id": 0})
agent.store_action(action)
game.step()

View File

@@ -24,12 +24,12 @@ def test_rng_seed_set(create_env):
env.reset(seed=3)
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
env.reset(seed=3)
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
assert a == b
@@ -40,11 +40,11 @@ def test_rng_seed_unset(create_env):
env.reset()
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
env.reset()
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
assert a != b

View File

@@ -91,7 +91,7 @@ def test_mask_contents_correct():
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "DONOTHING":
if act_type == "do_nothing":
assert mask[action_num]
if act_type == "NODE_SERVICE_DISABLE":

View File

@@ -32,10 +32,10 @@ FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.
def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the DoNothingAction can form a request and that it is accepted by the simulation."""
"""Test that the do_nothingAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
@@ -56,7 +56,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
assert svc.health_state_visible == SoftwareHealthState.UNUSED
# 2: Scan and check that the visible state is now correct
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"type": "node_service_scan", "node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.GOOD
@@ -67,7 +67,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
assert svc.health_state_visible == SoftwareHealthState.GOOD
# 4: Scan and check that the visible state is now correct
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"type": "node_service_scan", "node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.COMPROMISED
@@ -88,7 +88,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
svc.health_state_actual = SoftwareHealthState.COMPROMISED
# 2: Apply a patch action
action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0})
action = ("node_service_fix", {"type": "node_service_fix", "node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
@@ -96,7 +96,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
assert svc.health_state_actual == SoftwareHealthState.FIXING
# 4: perform a few do-nothing steps and check that the service is now in the good state
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.GOOD
@@ -121,18 +121,19 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
# 2: Add a rule to block client 1 from reaching server 2 on router
action = (
"ROUTER_ACL_ADDRULE",
"router_acl_add_rule",
{
"type": "router_acl_add_rule",
"target_router": "router",
"position": 4, # 4th rule
"permission": 2, # DENY
"source_ip_id": 3, # 10.0.1.2 (client_1)
"dest_ip_id": 6, # 10.0.2.3 (server_2)
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
"position": 4,
"permission": "DENY",
"src_ip": "10.0.1.2",
"src_wildcard": 0,
"src_port": "HTTP",
"dst_ip": "10.0.2.3",
"dst_wildcard": 0,
"dst_port": "HTTP",
"protocol_name": "udp",
},
)
agent.store_action(action)
@@ -148,24 +149,27 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
# 4: Add a rule to block server_1 from reaching server_2 on router (this should not affect comms as they are on same subnet)
action = (
"ROUTER_ACL_ADDRULE",
"router_acl_add_rule",
{
"type": "router_acl_add_rule",
"target_router": "router",
"position": 5, # 5th rule
"permission": 2, # DENY
"source_ip_id": 5, # 10.0.2.2 (server_1)
"dest_ip_id": 6, # 10.0.2.3 (server_2)
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
"permission": "DENY", # DENY
"src_ip": "10.0.2.2", # 10.0.2.2 (server_1)
"src_wildcard": 0,
"source_port": "ALL", # ALL
"dst_ip": "10.0.2.3", # 10.0.2.3 (server_2)
"dst_wildcard": 0,
"dst_port": "ALL", # ALL
"protocol_name": "ALL", # ALL
},
)
agent.store_action(action)
print(agent.most_recent_action)
game.step()
print(agent.most_recent_action)
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
print(router.acl.show())
assert router.acl.num_rules == 6
assert server_1.ping("10.0.2.3") # Can ping server_2
@@ -186,8 +190,9 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
# 2: Remove rule that allows HTTP traffic across the network
action = (
"ROUTER_ACL_REMOVERULE",
"router_acl_remove_rule",
{
"type": "router_acl_remove_rule",
"target_router": "router",
"position": 3, # 4th rule
},
@@ -219,10 +224,11 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
# 2: Disable the NIC on client_1
action = (
"HOST_NIC_DISABLE",
"host_nic_disable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"type": "host_nic_disable",
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -250,10 +256,11 @@ def test_host_nic_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAg
# 2: Use action to enable nic
action = (
"HOST_NIC_ENABLE",
"host_nic_enable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"type": "host_nic_enable",
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -277,11 +284,12 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge
# 2: perform a scan and make sure nothing has changed
action = (
"NODE_FILE_SCAN",
"node_file_scan",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads,
"file_id": 0, # cat.png
"type": "node_file_scan",
"node_name": "client_1", # client_1,
"folder_name": "downloads", # downloads,
"file_name": "cat.png", # cat.png
},
)
agent.store_action(action)
@@ -314,11 +322,12 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
# 2: delete the file
action = (
"NODE_FILE_DELETE",
"node_file_delete",
{
"node_id": 0, # client_1
"folder_id": 0, # downloads
"file_id": 0, # cat.png
"type": "node_file_delete",
"node_name": "client_1", # client_1
"folder_name": "downloads", # downloads
"file_name": "cat.png", # cat.png
},
)
agent.store_action(action)
@@ -334,14 +343,16 @@ def test_node_file_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a file is created."""
game, agent = game_and_agent
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
client_1 = game.simulation.network.get_node_by_hostname("client_1")
action = (
"NODE_FILE_CREATE",
"node_file_create",
{
"node_id": 0,
"type": "node_file_create",
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
"force": "False",
},
)
agent.store_action(action)
@@ -357,9 +368,10 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
action = (
"NODE_FILE_CREATE",
"node_file_create",
{
"node_id": 0,
"type": "node_file_create",
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
},
@@ -370,9 +382,10 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
assert client_1.file_system.get_file(folder_name="test", file_name="file.txt").num_access == 0
action = (
"NODE_FILE_ACCESS",
"node_file_access",
{
"node_id": 0,
"type": "node_file_access",
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
},
@@ -390,9 +403,10 @@ def test_node_folder_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
action = (
"NODE_FOLDER_CREATE",
"node_folder_create",
{
"node_id": 0,
"type": "node_folder_create",
"node_name": "client_1",
"folder_name": "test",
},
)
@@ -418,8 +432,9 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
# 2: Disable the NIC on client_1
action = (
"NETWORK_PORT_DISABLE",
"network_port_disable",
{
"type": "network_port_disable",
"target_nodename": "router", # router
"port_id": 1, # port 1
},
@@ -450,8 +465,9 @@ def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGa
# 2: Use action to enable port
action = (
"NETWORK_PORT_ENABLE",
"network_port_enable",
{
"type": "network_port_enable",
"target_nodename": "router", # router
"port_id": 1, # port 1
},
@@ -480,7 +496,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
assert browser.health_state_visible == SoftwareHealthState.UNUSED
# 2: Scan and check that the visible state is now correct
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = (
"node_application_scan",
{"type": "node_application_scan", "node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.GOOD
@@ -491,7 +510,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
assert browser.health_state_visible == SoftwareHealthState.GOOD
# 4: Scan and check that the visible state is now correct
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = (
"node_application_scan",
{"type": "node_application_scan", "node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.COMPROMISED
@@ -512,7 +534,10 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
browser.health_state_actual = SoftwareHealthState.COMPROMISED
# 2: Apply a fix action
action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0})
action = (
"node_application_fix",
{"type": "node_application_fix", "node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
@@ -520,7 +545,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
assert browser.health_state_actual == SoftwareHealthState.FIXING
# 4: perform a few do-nothing steps and check that the application is now in the good state
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.GOOD
@@ -538,7 +563,10 @@ def test_node_application_close_integration(game_and_agent: Tuple[PrimaiteGame,
assert browser.operating_state == ApplicationOperatingState.RUNNING
# 2: Apply a close action
action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0})
action = (
"node_application_close",
{"type": "node_application_close", "node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
@@ -549,7 +577,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
"""Test that the NodeApplicationInstallAction and NodeApplicationRemoveAction can form a request and that
it is accepted by the simulation.
When you initiate a install action, the Application will be installed and configured on the node.
When you initiate an install action, the Application will be installed and configured on the node.
The remove action will uninstall the application from the node."""
game, agent = game_and_agent
@@ -557,13 +585,19 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
assert client_1.software_manager.software.get("DoSBot") is None
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
action = (
"node_application_install",
{"type": "node_application_install", "node_name": "client_1", "application_name": "DoSBot"},
)
agent.store_action(action)
game.step()
assert client_1.software_manager.software.get("DoSBot") is not None
action = ("NODE_APPLICATION_REMOVE", {"node_id": 0, "application_name": "DoSBot"})
action = (
"node_application_remove",
{"type": "node_application_remove", "node_name": "client_1", "application_name": "DoSBot"},
)
agent.store_action(action)
game.step()
@@ -656,9 +690,9 @@ def test_firewall_acl_add_remove_rule_integration():
assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
assert firewall.external_outbound_acl.acl[1].dst_port is None
assert firewall.external_outbound_acl.acl[1].src_port is None
assert firewall.external_outbound_acl.acl[1].protocol is None
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
env.step(12) # Remove ACL rule from External Outbound
assert firewall.external_outbound_acl.num_rules == 1

View File

@@ -18,12 +18,14 @@ from tests import TEST_ASSETS_ROOT
from tests.conftest import ControlledAgent
def test_WebpageUnavailablePenalty(game_and_agent):
def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
"""Test that we get the right reward for failing to fetch a website."""
# set up the scenario, configure the web browser to the correct url
game, agent = game_and_agent
agent: ControlledAgent
comp = WebpageUnavailablePenalty(node_hostname="client_1")
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
comp = WebpageUnavailablePenalty(config=schema)
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
@@ -31,7 +33,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
agent.reward_function.register_component(comp, 0.7)
# Check that before trying to fetch the webpage, the reward is 0.0
agent.store_action(("DONOTHING", {}))
agent.store_action(("do_nothing", {}))
game.step()
assert agent.reward_function.current_reward == 0.0
@@ -53,7 +55,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
assert agent.reward_function.current_reward == -0.7
def test_uc2_rewards(game_and_agent):
def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
"""Test that the reward component correctly applies a penalty when the selected client cannot reach the database."""
game, agent = game_and_agent
agent: ControlledAgent
@@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent):
ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2
)
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
comp = GreenAdminDatabaseUnreachablePenalty(config=schema)
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
response = game.simulation.apply_request(request)
@@ -139,17 +142,19 @@ def test_action_penalty_loads_from_config():
act_penalty_obj = comp[0]
if act_penalty_obj is None:
pytest.fail("Action penalty reward component was not added to the agent from config.")
assert act_penalty_obj.action_penalty == -0.75
assert act_penalty_obj.do_nothing_penalty == 0.125
assert act_penalty_obj.config.action_penalty == -0.75
assert act_penalty_obj.config.do_nothing_penalty == 0.125
def test_action_penalty():
"""Test that the action penalty is correctly applied when agent performs any action"""
# Create an ActionPenalty Reward
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
# Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
Penalty = ActionPenalty(config=schema)
# Assert that penalty is applied if action isn't DONOTHING
# Assert that penalty is applied if action isn't do_nothing
reward_value = Penalty.calculate(
state={},
last_action_response=AgentHistoryItem(
@@ -163,12 +168,12 @@ def test_action_penalty():
assert reward_value == -0.75
# Assert that no penalty applied for a DONOTHING action
# Assert that no penalty applied for a do_nothing action
reward_value = Penalty.calculate(
state={},
last_action_response=AgentHistoryItem(
timestep=0,
action="DONOTHING",
action="do_nothing",
parameters={},
request=["do_nothing"],
response=RequestResponse.from_bool(True),
@@ -178,15 +183,16 @@ def test_action_penalty():
assert reward_value == 0.125
def test_action_penalty_e2e(game_and_agent):
def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
"""Test that we get the right reward for doing actions to fetch a website."""
game, agent = game_and_agent
agent: ControlledAgent
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
comp = ActionPenalty(config=schema)
agent.reward_function.register_component(comp, 1.0)
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == 0.125

View File

@@ -3,9 +3,11 @@ from unittest.mock import Mock
import pytest
from primaite.game.agent.actions import (
from primaite.game.agent.actions import ( # DoNothingAction,; NodeServiceDisableAction,; NodeServiceEnableAction,; NodeServicePauseAction,; NodeServiceRestartAction,; NodeServiceResumeAction,; NodeServiceScanAction,; NodeServiceStartAction,; NodeServiceStopAction,
ActionManager,
DoNothingAction,
)
from primaite.game.agent.actions.manager import DoNothingAction
from primaite.game.agent.actions.service import (
NodeServiceDisableAction,
NodeServiceEnableAction,
NodeServicePauseAction,
@@ -18,7 +20,7 @@ from primaite.game.agent.actions import (
def test_do_nothing_action_form_request():
"""Test that the DoNothingAction can form a request and that it is correct."""
"""Test that the do_nothingAction can form a request and that it is correct."""
manager = Mock()
action = DoNothingAction(manager=manager)

View File

@@ -28,9 +28,9 @@ def test_probabilistic_agent():
action_space_cfg = {
"action_list": [
{"type": "DONOTHING"},
{"type": "NODE_APPLICATION_EXECUTE"},
{"type": "NODE_FILE_DELETE"},
{"type": "do_nothing"},
{"type": "node_application_execute"},
{"type": "node_file_delete"},
],
"nodes": [
{
@@ -48,9 +48,9 @@ def test_probabilistic_agent():
"protocols": ["TCP", "UDP", "ICMP"],
"ports": ["HTTP", "DNS", "ARP"],
"act_map": {
0: {"action": "DONOTHING", "options": {}},
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
0: {"action": "do_nothing", "options": {}},
1: {"action": "node_application_execute", "options": {"node_id": 0, "application_id": 0}},
2: {"action": "node_file_delete", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
},
"options": {},
}
@@ -80,11 +80,11 @@ def test_probabilistic_agent():
node_file_delete_count = 0
for _ in range(N_TRIALS):
a = pa.get_action(0)
if a == ("DONOTHING", {}):
if a == ("do_nothing", {}):
do_nothing_count += 1
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
elif a == ("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}):
node_application_execute_count += 1
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
elif a == ("node_file_delete", {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}):
node_file_delete_count += 1
else:
raise AssertionError("Probabilistic agent produced an unexpected action.")

View File

@@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse
class TestWebServer404PenaltySticky:
def test_non_sticky(self):
reward = WebServer404Penalty("computer", "WebService", sticky=False)
schema = WebServer404Penalty.ConfigSchema(
node_hostname="computer",
service_name="WebService",
sticky=False,
)
reward = WebServer404Penalty(config=schema)
# no response codes yet, reward is 0
codes = []
@@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky:
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = WebServer404Penalty("computer", "WebService", sticky=True)
schema = WebServer404Penalty.ConfigSchema(
node_hostname="computer",
service_name="WebService",
sticky=True,
)
reward = WebServer404Penalty(config=schema)
# no response codes yet, reward is 0
codes = []
@@ -67,10 +77,11 @@ class TestWebServer404PenaltySticky:
class TestWebpageUnavailabilitySticky:
def test_non_sticky(self):
reward = WebpageUnavailablePenalty("computer", sticky=False)
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False)
reward = WebpageUnavailablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "do_nothing", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
@@ -93,7 +104,7 @@ class TestWebpageUnavailabilitySticky:
# THE IMPORTANT BIT
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
@@ -127,10 +138,11 @@ class TestWebpageUnavailabilitySticky:
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = WebpageUnavailablePenalty("computer", sticky=True)
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True)
reward = WebpageUnavailablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
@@ -153,7 +165,7 @@ class TestWebpageUnavailabilitySticky:
# THE IMPORTANT BIT
# agent did nothing, because reward is sticky, it stays at 1.0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
@@ -188,10 +200,14 @@ class TestWebpageUnavailabilitySticky:
class TestGreenAdminDatabaseUnreachableSticky:
def test_non_sticky(self):
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(
node_hostname="computer",
sticky=False,
)
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
@@ -212,9 +228,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
# THE IMPORTANT BIT
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
@@ -244,10 +259,14 @@ class TestGreenAdminDatabaseUnreachableSticky:
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(
node_hostname="computer",
sticky=True,
)
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
@@ -268,7 +287,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
# THE IMPORTANT BIT
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
action, params, request = "DO_NOTHING", {}, ["do_nothing"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(