#2912 - Corrections to some actions & fixing some linting. TODO: Action Manager errors

This commit is contained in:
Charlie Crane
2024-10-22 17:02:54 +01:00
parent 11357f87ca
commit 518b934e09
11 changed files with 149 additions and 41 deletions

View File

@@ -3,10 +3,16 @@ from typing import Dict, List, Literal
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
from primaite.interface.request import RequestFormat
__all__ = ("RouterACLAddRuleAction", "RouterACLRemoveRuleAction", "FirewallACLAddRuleAction", "FirewallACLRemoveRuleAction")
__all__ = (
"RouterACLAddRuleAction",
"RouterACLRemoveRuleAction",
"FirewallACLAddRuleAction",
"FirewallACLRemoveRuleAction",
)
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
@@ -23,8 +29,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
dest_ip_id: int
dest_wildcard_id: int
dest_port_id: int
protocol_id: int
protocol_name: str
class ACLRuleOptions(BaseModel):
"""Validator for ACL_ADD_RULE options."""
@@ -68,10 +73,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
return v
@classmethod
def form_request(
cls,
config: ConfigSchema
) -> List[str]:
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
# Validate incoming data.
parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
@@ -84,7 +86,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
dest_wildcard_id=config.dest_wildcard_id,
source_port_id=config.source_port_id,
dest_port_id=config.dest_port_id,
protocol_id=config.protocol_id,
protocol=config.protocol_name,
)
if parsed_options.permission == 1:
permission_str = "PERMIT"
@@ -96,40 +98,40 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
if parsed_options.protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
protocol = cls.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if parsed_options.source_ip_id == 1:
src_ip = "ALL"
else:
src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2)
src_ip = cls.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id)
src_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id)
if parsed_options.source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2)
src_port = cls.manager.get_port_by_idx(parsed_options.source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if parsed_options.dest_ip_id == 1:
dst_ip = "ALL"
else:
dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
dst_ip = cls.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
dst_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
if parsed_options.dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2)
dst_port = cls.manager.get_port_by_idx(parsed_options.dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
"network",
"node",
target_router,
config.target_router,
"acl",
"add_rule",
permission_str,
@@ -140,7 +142,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
str(dst_ip),
dst_wildcard,
dst_port,
position,
config.position,
]

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
from primaite.interface.request import RequestFormat
__all__ = (

View File

@@ -87,6 +87,8 @@ class ActionManager:
# ip_list: List[str] = [], # to allow us to map an index to an ip address.
# wildcard_list: List[str] = [], # to allow mapping from wildcard index to
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
*args,
**kwargs,
) -> None:
"""Init method for ActionManager.
@@ -116,27 +118,27 @@ class ActionManager:
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.node_names: List[str] = [n["node_name"] for n in nodes]
# self.node_names: List[str] = [n["node_name"] for n in nodes]
"""List of node names in this action space. The list order is the mapping between node index and node name."""
self.application_names: List[List[str]] = []
# self.application_names: List[List[str]] = []
"""
List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name.
The first index corresponds to node id, the second index is the app id on that particular node.
For instance, self.application_names[0][2] is the name of the third application on the first node.
"""
self.service_names: List[List[str]] = []
# self.service_names: List[List[str]] = []
"""
List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name.
The first index corresponds to node id, the second index is the service id on that particular node.
For instance, self.service_names[0][2] is the name of the third service on the first node.
"""
self.folder_names: List[List[str]] = []
# self.folder_names: List[List[str]] = []
"""
List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder
name. The first index corresponds to node id, the second index is the folder id on that particular node.
For instance, self.folder_names[0][2] is the name of the third folder on the first node.
"""
self.file_names: List[List[List[str]]] = []
# self.file_names: List[List[List[str]]] = []
"""
List of files per folder per node. The list order gives the three-index mapping between
(node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the
@@ -203,7 +205,7 @@ class ActionManager:
# and `options` is an optional dict of options to pass to the init method of the action class
act_type = act_spec.get("type")
act_options = act_spec.get("options", {})
self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options)
# self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options)
self.action_map: Dict[int, Tuple[str, Dict]] = {}
"""

View File

@@ -1,5 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import ClassVar
from abc import abstractmethod
from typing import ClassVar, List, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
@@ -61,3 +62,97 @@ class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
"""Configuration schema for NodeResetAction."""
verb: str = "reset"
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
"""Base class for NodeNMAP actions."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration Schema for NodeNMAP actions."""
target_ip_address: Union[str, List[str]]
show: bool = False
node_name: str
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
pass
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
pass
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"application",
"NMAP",
"ping_scan",
{"target_ip_address": config.target_ip_address, "show": config.show},
]
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
"""Action which performs an NMAP port scan."""
class ConfigSchema(AbstractAction.ConfigSchema):
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
show: Optional[bool] = (False,)
@classmethod
def form_request(
cls,
config: ConfigSchema,
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"port_scan",
{
"target_ip_address": config.target_ip_address,
"target_port": config.target_port,
"target_protocol": config.target_protocol,
"show": config.show,
},
]
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
class ConfigSchema(AbstractAction.ConfigSchema):
target_protocol: Optional[Union[str, List[str]]] = (None,)
target_port: Optional[Union[str, List[str]]] = (None,)
show: Optional[bool] = (False,)
@classmethod
def form_request(
cls,
config: ConfigSchema,
) -> List[str]: # noqa
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"network_service_recon",
{
"target_ip_address": config.target_ip_address,
"target_port": config.target_port,
"target_protocol": config.target_protocol,
"show": config.show,
},
]

View File

@@ -18,6 +18,11 @@ __all__ = (
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
"""Abstract Action for Node Service related actions.
Any actions which use node_name and service_name can inherit from this class.
"""
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
service_name: str
@@ -34,7 +39,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_
"""Action which scans a service."""
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceScanAction"""
"""Configuration Schema for NodeServiceScanAction."""
verb: str = "scan"

View File

@@ -19,8 +19,11 @@ class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstrac
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Abstract method. Should return the action formatted as a request which
can be ingested by the PrimAITE simulation."""
"""
Abstract method for request forming.
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
"""
pass

View File

@@ -19,7 +19,8 @@
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable\n"
"from prettytable import PrettyTable\n",
"UDP=\"UDP\""
]
},
{
@@ -195,7 +196,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -209,7 +210,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -1783,7 +1783,7 @@
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n",
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n",
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol[\"UDP\"], masquerade_port=Port[\"DNS\"])\n",
"c2_beacon.establish()\n",
"c2_beacon.show()"
]
@@ -1804,7 +1804,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},

View File

@@ -168,7 +168,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -182,7 +182,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -182,7 +182,7 @@
"metadata": {},
"outputs": [],
"source": [
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port[\"HTTP\"], protocol = IPProtocol[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
]
},
{
@@ -249,7 +249,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -263,7 +263,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -537,7 +537,7 @@
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
" action=ACLAction.DENY,\n",
" protocol=IPProtocol["ICMP"],\n",
" protocol=IPProtocol[\"ICMP\"],\n",
" src_ip_address=\"192.168.10.22\",\n",
" position=1\n",
")"
@@ -650,7 +650,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},