#2912 - Pre-commit updates ahead of first draft PR.

This commit is contained in:
Charlie Crane
2024-11-13 10:40:51 +00:00
parent d757bd01f0
commit ed020f005f
9 changed files with 36 additions and 161 deletions

View File

@@ -29,7 +29,7 @@ New actions to be used within PrimAITE require:
class ConfigSchema(AbstractAction.ConfigSchema):
target_application: str
The ConfigSchema is used when the class is called to form the action.
The ConfigSchema is used when the class is called to form the action, within the `form_request` method, detailed below.
#. **Unique Identifier**:

View File

@@ -297,7 +297,7 @@ class NodeFolderAbstractAction(AbstractAction):
class NodeFolderScanAction(NodeFolderAbstractAction):
"""Action which scans a folder."""
def __init__(self, manager: "ActionManager", node_name: str, folder_name, **kwargs) -> None:
def __init__(self, manager: "ActionManager", node_name: str, folder_name: str, **kwargs) -> None:
super().__init__(manager, node_name=node_name, folder_name=folder_name, **kwargs)
self.verb: str = "scan"

View File

@@ -1,9 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, List, Literal
from typing import List, Literal
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
@@ -20,6 +18,9 @@ class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"):
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL abstract actions."""
src_ip: str
protocol_name: str
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
@@ -27,13 +28,11 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
target_router: str
position: int
permission: Literal[1, 2]
src_ip: str
source_wildcard_id: int
source_port: str
dst_ip: str
dst_wildcard: int
dst_port: int
protocol_name: str
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for RouterACLAddRuleAction."""
@@ -47,66 +46,10 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
dst_ip: str
dst_wildcard: int
dst_port: int
protocol_name: str
class ACLRuleOptions(BaseModel):
"""Validator for ACL_ADD_RULE options."""
target_router: str
"""On which router to add the rule, must be specified."""
position: int
"""At what position to add the rule, must be specified."""
permission: Literal[1, 2]
"""Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY."""
src_ip: str
"""Rule source IP address. By default, all ip addresses."""
source_wildcard_id: int = Field(default=0, ge=0)
"""Rule source IP wildcard. By default, use the wildcard at index 0 from action manager."""
source_port: int = Field(default=1, ge=1)
"""Rule source port. By default, all source ports."""
dst_ip_id: int = Field(default=1, ge=1)
"""Rule destination IP address. By default, all ip addresses."""
dst_wildcard: int = Field(default=0, ge=0)
"""Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager."""
dst_port_id: int = Field(default=1, ge=1)
"""Rule destination port. By default, all destination ports."""
protocol_name: str = "ALL"
"""Rule protocol. By default, all protocols."""
# @field_validator(
# "source_ip_id",
# "source_port_id",
# "source_wildcard_id",
# "dest_ip_id",
# "dest_port_id",
# "dest_wildcard_id",
# "protocol_name",
# mode="before",
# )
# @classmethod
# def not_none(cls, v: str, info: ValidationInfo) -> int:
# """If None is passed, use the default value instead."""
# if v is None:
# return cls.model_fields[info.field_name].default
# return v
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
# Validate incoming data.
# parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
# target_router=config.target_router,
# position=config.position,
# permission=config.permission,
# src_ip=config.src_ip,
# source_wildcard_id=config.source_wildcard_id,
# dst_ip_id=config.dst_ip,
# dst_wildcard=config.dst_wildcard,
# source_port_id=config.source_port_id,
# dest_port=config.dst_port,
# protocol=config.protocol_name,
# )
return [
"network",
"node",
@@ -118,7 +61,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
config.src_ip,
config.src_wildcard,
config.source_port,
str(config.dst_ip),
config.dst_ip,
config.dst_wildcard,
config.dst_port,
config.position,
@@ -160,63 +103,11 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
num_permissions: int = 3
permission: str
def __init__(
self,
manager: "ActionManager",
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
"""Init method for FirewallACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
:type num_ips: int
:param num_ports: Number of ports in the simulation.
:type num_ports: int
:param num_protocols: Number of protocols in the simulation.
:type num_protocols: int
"""
super().__init__(manager=manager)
num_permissions = 3
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_id": num_ips,
"dest_ip_id": num_ips,
"source_port_id": num_ports,
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif config.permission == 1:
permission_str = "PERMIT"
elif config.permission == 2:
permission_str = "DENY"
# else:
# _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if config.protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if config.protocol_id == 1:
protocol = "ALL"
else:
# protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
pass
if config.source_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif config.source_ip_id == 1:
@@ -235,26 +126,6 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
# subtract 2 to account for UNUSED=0, and ALL=1
pass
if config.dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif config.dest_ip_id == 1:
dst_ip = "ALL"
else:
# dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
pass
if config.dest_port_id == 0:
return ["do_nothing"] # invalid formulation
elif config.dest_port_id == 1:
dst_port = "ALL"
else:
# dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
# src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
# dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
pass
return [
"network",
"node",
@@ -264,13 +135,13 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
"acl",
"add_rule",
config.permission,
protocol,
config.protocol_name,
str(src_ip),
config.src_wildcard,
src_port,
str(dst_ip),
config.dst_ip,
config.dst_wildcard,
dst_port,
config.dst_port,
config.position,
]
@@ -278,29 +149,24 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"):
"""Action which removes a rule from a firewall port's ACL."""
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
"""Init method for RouterACLRemoveRuleAction.
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLRemoveRuleAction."""
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
position: int
@classmethod
def form_request(
cls, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int
) -> List[str]:
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
target_firewall_nodename,
firewall_port_name,
firewall_port_direction,
config.target_firewall_nodename,
config.firewall_port_name,
config.firewall_port_direction,
"acl",
"remove_rule",
position,
config.position,
]

View File

@@ -250,7 +250,6 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa
def form_request(self, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", config.model_config]

View File

@@ -178,7 +178,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_chec
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
"""Action which repairs a file"""
"""Action which repairs a file."""
verb: str = "repair"

View File

@@ -149,7 +149,8 @@ class ActionManager:
Since the agent uses a discrete action space which acts as a flattened version of the component-based
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
correspond to "node_service_scan" with ``node_name="server"`` and ``service_name="WebBrowser"``, action 2 can be "
correspond to "node_service_scan" with ``node_name="server"`` and
``service_name="WebBrowser"``, action 2 can be "
3. ``options``
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
These options are used to calculate the shape of the action space, and to provide additional information

View File

@@ -89,7 +89,11 @@ class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_acti
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
"""Action which performs an NMAP ping scan."""
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration schema for NodeNMAPPingScanAction."""
pass
@classmethod
@@ -110,6 +114,8 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_
"""Action which performs an NMAP port scan."""
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration Schema for NodeNMAPPortScanAction."""
source_node: str
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
@@ -141,6 +147,8 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_net
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NodeNetworkServiceReconAction."""
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
show: Optional[bool] = (False,)

View File

@@ -67,14 +67,14 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
pass
verb: str = "remote_logoff"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "service", "Terminal", "remote_logoff", config.remote_ip]
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):

View File

@@ -166,6 +166,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
game.step()
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
print(router.acl.show())
assert router.acl.num_rules == 6
assert server_1.ping("10.0.2.3") # Can ping server_2