#2682 Backport changes to core PrimAITE

This commit is contained in:
Marek Wolan
2024-06-25 11:04:52 +01:00
parent 4a81dc3b2c
commit 28dabad66b
52 changed files with 456 additions and 296 deletions

View File

@@ -11,9 +11,10 @@ AbstractAction. The ActionManager is responsible for:
"""
import itertools
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
from gymnasium import spaces
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from primaite import getLogger
@@ -321,12 +322,12 @@ class NodeFileCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
return ["do_nothing"]
return ["network", "node", node_name, "file_system", "create", "file", folder_name, file_name]
return ["network", "node", node_name, "file_system", "create", "file", folder_name, file_name, force]
class NodeFolderCreateAction(AbstractAction):
@@ -493,6 +494,47 @@ class NodeResetAction(NodeAbstractAction):
class RouterACLAddRuleAction(AbstractAction):
"""Action which adds a rule to a router's ACL."""
class ACLRuleOptions(BaseModel):
"""Validator for ACL_ADD_RULE options."""
target_router: str
"""On which router to add the rule, must be specified."""
position: int
"""At what position to add the rule, must be specified."""
permission: Literal[1, 2]
"""Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY."""
source_ip_id: int = Field(default=1, ge=1)
"""Rule source IP address. By default, all ip addresses."""
source_wildcard_id: int = Field(default=0, ge=0)
"""Rule source IP wildcard. By default, use the wildcard at index 0 from action manager."""
source_port_id: int = Field(default=1, ge=1)
"""Rule source port. By default, all source ports."""
dest_ip_id: int = Field(default=1, ge=1)
"""Rule destination IP address. By default, all ip addresses."""
dest_wildcard_id: int = Field(default=0, ge=0)
"""Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager."""
dest_port_id: int = Field(default=1, ge=1)
"""Rule destination port. By default, all destination ports."""
protocol_id: int = Field(default=1, ge=1)
"""Rule protocol. By default, all protocols."""
@field_validator(
"source_ip_id",
"source_port_id",
"source_wildcard_id",
"dest_ip_id",
"dest_port_id",
"dest_wildcard_id",
"protocol_id",
mode="before",
)
@classmethod
def not_none(cls, v: str, info: ValidationInfo) -> int:
"""If None is passed, use the default value instead."""
if v is None:
return cls.model_fields[info.field_name].default
return v
def __init__(
self,
manager: "ActionManager",
@@ -529,7 +571,7 @@ class RouterACLAddRuleAction(AbstractAction):
def form_request(
self,
target_router_nodename: str,
target_router: str,
position: int,
permission: int,
source_ip_id: int,
@@ -541,62 +583,63 @@ class RouterACLAddRuleAction(AbstractAction):
protocol_id: int,
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
# Validate incoming data.
parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
target_router=target_router,
position=position,
permission=permission,
source_ip_id=source_ip_id,
source_wildcard_id=source_wildcard_id,
dest_ip_id=dest_ip_id,
dest_wildcard_id=dest_wildcard_id,
source_port_id=source_port_id,
dest_port_id=dest_port_id,
protocol_id=protocol_id,
)
if parsed_options.permission == 1:
permission_str = "PERMIT"
elif permission == 2:
elif parsed_options.permission == 2:
permission_str = "DENY"
else:
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if protocol_id == 1:
if parsed_options.protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if source_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif source_ip_id == 1:
if parsed_options.source_ip_id == 1:
src_ip = "ALL"
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
if source_port_id == 0:
return ["do_nothing"] # invalid formulation
elif source_port_id == 1:
src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id)
if parsed_options.source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id - 2)
src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
if parsed_options.dest_ip_id == 1:
dst_ip = "ALL"
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
if dest_port_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_port_id == 1:
if parsed_options.dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
"network",
"node",
target_router_nodename,
target_router,
"acl",
"add_rule",
permission_str,
@@ -625,9 +668,9 @@ class RouterACLRemoveRuleAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
def form_request(self, target_router_nodename: str, position: int) -> List[str]:
def form_request(self, target_router: str, position: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", target_router_nodename, "acl", "remove_rule", position]
return ["network", "node", target_router, "acl", "remove_rule", position]
class FirewallACLAddRuleAction(AbstractAction):
@@ -877,7 +920,10 @@ class NodeNMAPPingScanAction(AbstractAction):
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, source_node: str, target_ip_address: Union[str, List[str]]) -> List[str]: # noqa
def form_request(
self, source_node: str, target_ip_address: Union[str, List[str]], show: Optional[bool] = False
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
@@ -885,7 +931,7 @@ class NodeNMAPPingScanAction(AbstractAction):
"application",
"NMAP",
"ping_scan",
{"target_ip_address": target_ip_address},
{"target_ip_address": target_ip_address, "show": show},
]
@@ -901,6 +947,7 @@ class NodeNMAPPortScanAction(AbstractAction):
target_ip_address: Union[str, List[str]],
target_protocol: Optional[Union[str, List[str]]] = None,
target_port: Optional[Union[str, List[str]]] = None,
show: Optional[bool] = False,
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
@@ -910,7 +957,12 @@ class NodeNMAPPortScanAction(AbstractAction):
"application",
"NMAP",
"port_scan",
{"target_ip_address": target_ip_address, "target_port": target_port, "target_protocol": target_protocol},
{
"target_ip_address": target_ip_address,
"target_port": target_port,
"target_protocol": target_protocol,
"show": show,
},
]
@@ -926,6 +978,7 @@ class NodeNetworkServiceReconAction(AbstractAction):
target_ip_address: Union[str, List[str]],
target_protocol: Optional[Union[str, List[str]]] = None,
target_port: Optional[Union[str, List[str]]] = None,
show: Optional[bool] = False,
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
@@ -935,7 +988,12 @@ class NodeNetworkServiceReconAction(AbstractAction):
"application",
"NMAP",
"network_service_recon",
{"target_ip_address": target_ip_address, "target_port": target_port, "target_protocol": target_protocol},
{
"target_ip_address": target_ip_address,
"target_port": target_port,
"target_protocol": target_protocol,
"show": show,
},
]