#2912 - Pre-commit updates ahead of first draft PR.
This commit is contained in:
@@ -29,7 +29,7 @@ New actions to be used within PrimAITE require:
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
target_application: str
|
||||
|
||||
The ConfigSchema is used when the class is called to form the action.
|
||||
The ConfigSchema is used when the class is called to form the action, within the `form_request` method, detailed below.
|
||||
|
||||
|
||||
#. **Unique Identifier**:
|
||||
|
||||
@@ -297,7 +297,7 @@ class NodeFolderAbstractAction(AbstractAction):
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
|
||||
"""Action which scans a folder."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", node_name: str, folder_name, **kwargs) -> None:
|
||||
def __init__(self, manager: "ActionManager", node_name: str, folder_name: str, **kwargs) -> None:
|
||||
super().__init__(manager, node_name=node_name, folder_name=folder_name, **kwargs)
|
||||
self.verb: str = "scan"
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Literal
|
||||
from typing import List, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
@@ -20,6 +18,9 @@ class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL abstract actions."""
|
||||
|
||||
src_ip: str
|
||||
protocol_name: str
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
@@ -27,13 +28,11 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
target_router: str
|
||||
position: int
|
||||
permission: Literal[1, 2]
|
||||
src_ip: str
|
||||
source_wildcard_id: int
|
||||
source_port: str
|
||||
dst_ip: str
|
||||
dst_wildcard: int
|
||||
dst_port: int
|
||||
protocol_name: str
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for RouterACLAddRuleAction."""
|
||||
@@ -47,66 +46,10 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
dst_ip: str
|
||||
dst_wildcard: int
|
||||
dst_port: int
|
||||
protocol_name: str
|
||||
|
||||
class ACLRuleOptions(BaseModel):
|
||||
"""Validator for ACL_ADD_RULE options."""
|
||||
|
||||
target_router: str
|
||||
"""On which router to add the rule, must be specified."""
|
||||
position: int
|
||||
"""At what position to add the rule, must be specified."""
|
||||
permission: Literal[1, 2]
|
||||
"""Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY."""
|
||||
src_ip: str
|
||||
"""Rule source IP address. By default, all ip addresses."""
|
||||
source_wildcard_id: int = Field(default=0, ge=0)
|
||||
"""Rule source IP wildcard. By default, use the wildcard at index 0 from action manager."""
|
||||
source_port: int = Field(default=1, ge=1)
|
||||
"""Rule source port. By default, all source ports."""
|
||||
dst_ip_id: int = Field(default=1, ge=1)
|
||||
"""Rule destination IP address. By default, all ip addresses."""
|
||||
dst_wildcard: int = Field(default=0, ge=0)
|
||||
"""Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager."""
|
||||
dst_port_id: int = Field(default=1, ge=1)
|
||||
"""Rule destination port. By default, all destination ports."""
|
||||
protocol_name: str = "ALL"
|
||||
"""Rule protocol. By default, all protocols."""
|
||||
|
||||
# @field_validator(
|
||||
# "source_ip_id",
|
||||
# "source_port_id",
|
||||
# "source_wildcard_id",
|
||||
# "dest_ip_id",
|
||||
# "dest_port_id",
|
||||
# "dest_wildcard_id",
|
||||
# "protocol_name",
|
||||
# mode="before",
|
||||
# )
|
||||
# @classmethod
|
||||
# def not_none(cls, v: str, info: ValidationInfo) -> int:
|
||||
# """If None is passed, use the default value instead."""
|
||||
# if v is None:
|
||||
# return cls.model_fields[info.field_name].default
|
||||
# return v
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
# Validate incoming data.
|
||||
# parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
|
||||
# target_router=config.target_router,
|
||||
# position=config.position,
|
||||
# permission=config.permission,
|
||||
# src_ip=config.src_ip,
|
||||
# source_wildcard_id=config.source_wildcard_id,
|
||||
# dst_ip_id=config.dst_ip,
|
||||
# dst_wildcard=config.dst_wildcard,
|
||||
# source_port_id=config.source_port_id,
|
||||
# dest_port=config.dst_port,
|
||||
# protocol=config.protocol_name,
|
||||
# )
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
@@ -118,7 +61,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
config.src_ip,
|
||||
config.src_wildcard,
|
||||
config.source_port,
|
||||
str(config.dst_ip),
|
||||
config.dst_ip,
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
@@ -160,63 +103,11 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
num_permissions: int = 3
|
||||
permission: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: "ActionManager",
|
||||
max_acl_rules: int,
|
||||
num_ips: int,
|
||||
num_ports: int,
|
||||
num_protocols: int,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for FirewallACLAddRuleAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
:param num_ips: Number of IP addresses in the simulation.
|
||||
:type num_ips: int
|
||||
:param num_ports: Number of ports in the simulation.
|
||||
:type num_ports: int
|
||||
:param num_protocols: Number of protocols in the simulation.
|
||||
:type num_protocols: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
num_permissions = 3
|
||||
self.shape: Dict[str, int] = {
|
||||
"position": max_acl_rules,
|
||||
"permission": num_permissions,
|
||||
"source_ip_id": num_ips,
|
||||
"dest_ip_id": num_ips,
|
||||
"source_port_id": num_ports,
|
||||
"dest_port_id": num_ports,
|
||||
"protocol_id": num_protocols,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.permission == 0:
|
||||
permission_str = "UNUSED"
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
elif config.permission == 1:
|
||||
permission_str = "PERMIT"
|
||||
elif config.permission == 2:
|
||||
permission_str = "DENY"
|
||||
# else:
|
||||
# _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
|
||||
if config.protocol_id == 0:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
|
||||
if config.protocol_id == 1:
|
||||
protocol = "ALL"
|
||||
else:
|
||||
# protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
pass
|
||||
|
||||
if config.source_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif config.source_ip_id == 1:
|
||||
@@ -235,26 +126,6 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
pass
|
||||
|
||||
if config.dest_ip_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif config.dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
else:
|
||||
# dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
pass
|
||||
|
||||
if config.dest_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif config.dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
# dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
# src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
|
||||
# dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
|
||||
pass
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
@@ -264,13 +135,13 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
protocol,
|
||||
config.protocol_name,
|
||||
str(src_ip),
|
||||
config.src_wildcard,
|
||||
src_port,
|
||||
str(dst_ip),
|
||||
config.dst_ip,
|
||||
config.dst_wildcard,
|
||||
dst_port,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
@@ -278,29 +149,24 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"):
|
||||
"""Action which removes a rule from a firewall port's ACL."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
|
||||
"""Init method for RouterACLRemoveRuleAction.
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLRemoveRuleAction."""
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
|
||||
:type max_acl_rules: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"position": max_acl_rules}
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
position: int
|
||||
|
||||
@classmethod
|
||||
def form_request(
|
||||
cls, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int
|
||||
) -> List[str]:
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
target_firewall_nodename,
|
||||
firewall_port_name,
|
||||
firewall_port_direction,
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"remove_rule",
|
||||
position,
|
||||
config.position,
|
||||
]
|
||||
|
||||
@@ -250,7 +250,6 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa
|
||||
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", config.model_config]
|
||||
|
||||
@@ -178,7 +178,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_chec
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
|
||||
"""Action which repairs a file"""
|
||||
"""Action which repairs a file."""
|
||||
|
||||
verb: str = "repair"
|
||||
|
||||
|
||||
@@ -149,7 +149,8 @@ class ActionManager:
|
||||
Since the agent uses a discrete action space which acts as a flattened version of the component-based
|
||||
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
|
||||
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
|
||||
correspond to "node_service_scan" with ``node_name="server"`` and ``service_name="WebBrowser"``, action 2 can be "
|
||||
correspond to "node_service_scan" with ``node_name="server"`` and
|
||||
``service_name="WebBrowser"``, action 2 can be "
|
||||
3. ``options``
|
||||
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
|
||||
These options are used to calculate the shape of the action space, and to provide additional information
|
||||
|
||||
@@ -89,7 +89,11 @@ class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_acti
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
|
||||
"""Action which performs an NMAP ping scan."""
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNMAPPingScanAction."""
|
||||
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@@ -110,6 +114,8 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeNMAPPortScanAction."""
|
||||
|
||||
source_node: str
|
||||
target_protocol: Optional[Union[str, List[str]]] = (None,)
|
||||
target_port: Optional[Union[str, List[str]]] = (None,)
|
||||
@@ -141,6 +147,8 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_net
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNetworkServiceReconAction."""
|
||||
|
||||
target_protocol: Optional[Union[str, List[str]]] = (None,)
|
||||
target_port: Optional[Union[str, List[str]]] = (None,)
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@@ -67,14 +67,14 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
|
||||
|
||||
pass
|
||||
verb: str = "remote_logoff"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", config.node_name, "service", "Terminal", "remote_logoff", config.remote_ip]
|
||||
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
|
||||
|
||||
|
||||
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):
|
||||
|
||||
@@ -166,6 +166,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
|
||||
game.step()
|
||||
|
||||
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
|
||||
print(router.acl.show())
|
||||
assert router.acl.num_rules == 6
|
||||
assert server_1.ping("10.0.2.3") # Can ping server_2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user