#2912 - Corrections to some actions & fixing some linting. TODO: Action Manager errors
This commit is contained in:
@@ -3,10 +3,16 @@ from typing import Dict, List, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("RouterACLAddRuleAction", "RouterACLRemoveRuleAction", "FirewallACLAddRuleAction", "FirewallACLRemoveRuleAction")
|
||||
__all__ = (
|
||||
"RouterACLAddRuleAction",
|
||||
"RouterACLRemoveRuleAction",
|
||||
"FirewallACLAddRuleAction",
|
||||
"FirewallACLRemoveRuleAction",
|
||||
)
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
@@ -23,8 +29,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
dest_ip_id: int
|
||||
dest_wildcard_id: int
|
||||
dest_port_id: int
|
||||
protocol_id: int
|
||||
|
||||
protocol_name: str
|
||||
|
||||
class ACLRuleOptions(BaseModel):
|
||||
"""Validator for ACL_ADD_RULE options."""
|
||||
@@ -68,10 +73,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def form_request(
|
||||
cls,
|
||||
config: ConfigSchema
|
||||
) -> List[str]:
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
# Validate incoming data.
|
||||
parsed_options = RouterACLAddRuleAction.ACLRuleOptions(
|
||||
@@ -84,7 +86,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
dest_wildcard_id=config.dest_wildcard_id,
|
||||
source_port_id=config.source_port_id,
|
||||
dest_port_id=config.dest_port_id,
|
||||
protocol_id=config.protocol_id,
|
||||
protocol=config.protocol_name,
|
||||
)
|
||||
if parsed_options.permission == 1:
|
||||
permission_str = "PERMIT"
|
||||
@@ -96,40 +98,40 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
if parsed_options.protocol_id == 1:
|
||||
protocol = "ALL"
|
||||
else:
|
||||
protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
|
||||
protocol = cls.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2)
|
||||
# subtract 2 to account for UNUSED=0 and ALL=1.
|
||||
|
||||
if parsed_options.source_ip_id == 1:
|
||||
src_ip = "ALL"
|
||||
else:
|
||||
src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2)
|
||||
src_ip = cls.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id)
|
||||
src_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id)
|
||||
|
||||
if parsed_options.source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2)
|
||||
src_port = cls.manager.get_port_by_idx(parsed_options.source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
if parsed_options.dest_ip_id == 1:
|
||||
dst_ip = "ALL"
|
||||
else:
|
||||
dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
|
||||
dst_ip = cls.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
|
||||
dst_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id)
|
||||
|
||||
if parsed_options.dest_port_id == 1:
|
||||
dst_port = "ALL"
|
||||
else:
|
||||
dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2)
|
||||
dst_port = cls.manager.get_port_by_idx(parsed_options.dest_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
target_router,
|
||||
config.target_router,
|
||||
"acl",
|
||||
"add_rule",
|
||||
permission_str,
|
||||
@@ -140,7 +142,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
str(dst_ip),
|
||||
dst_wildcard,
|
||||
dst_port,
|
||||
position,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
|
||||
@@ -87,6 +87,8 @@ class ActionManager:
|
||||
# ip_list: List[str] = [], # to allow us to map an index to an ip address.
|
||||
# wildcard_list: List[str] = [], # to allow mapping from wildcard index to
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init method for ActionManager.
|
||||
|
||||
@@ -116,27 +118,27 @@ class ActionManager:
|
||||
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
|
||||
:type act_map: Optional[Dict[int, Dict]]
|
||||
"""
|
||||
self.node_names: List[str] = [n["node_name"] for n in nodes]
|
||||
# self.node_names: List[str] = [n["node_name"] for n in nodes]
|
||||
"""List of node names in this action space. The list order is the mapping between node index and node name."""
|
||||
self.application_names: List[List[str]] = []
|
||||
# self.application_names: List[List[str]] = []
|
||||
"""
|
||||
List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name.
|
||||
The first index corresponds to node id, the second index is the app id on that particular node.
|
||||
For instance, self.application_names[0][2] is the name of the third application on the first node.
|
||||
"""
|
||||
self.service_names: List[List[str]] = []
|
||||
# self.service_names: List[List[str]] = []
|
||||
"""
|
||||
List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name.
|
||||
The first index corresponds to node id, the second index is the service id on that particular node.
|
||||
For instance, self.service_names[0][2] is the name of the third service on the first node.
|
||||
"""
|
||||
self.folder_names: List[List[str]] = []
|
||||
# self.folder_names: List[List[str]] = []
|
||||
"""
|
||||
List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder
|
||||
name. The first index corresponds to node id, the second index is the folder id on that particular node.
|
||||
For instance, self.folder_names[0][2] is the name of the third folder on the first node.
|
||||
"""
|
||||
self.file_names: List[List[List[str]]] = []
|
||||
# self.file_names: List[List[List[str]]] = []
|
||||
"""
|
||||
List of files per folder per node. The list order gives the three-index mapping between
|
||||
(node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the
|
||||
@@ -203,7 +205,7 @@ class ActionManager:
|
||||
# and `options` is an optional dict of options to pass to the init method of the action class
|
||||
act_type = act_spec.get("type")
|
||||
act_options = act_spec.get("options", {})
|
||||
self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
# self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options)
|
||||
|
||||
self.action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import ClassVar
|
||||
from abc import abstractmethod
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
@@ -61,3 +62,97 @@ class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
|
||||
"""Configuration schema for NodeResetAction."""
|
||||
|
||||
verb: str = "reset"
|
||||
|
||||
|
||||
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
|
||||
"""Base class for NodeNMAP actions."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration Schema for NodeNMAP actions."""
|
||||
|
||||
target_ip_address: Union[str, List[str]]
|
||||
show: bool = False
|
||||
node_name: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
pass
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"application",
|
||||
"NMAP",
|
||||
"ping_scan",
|
||||
{"target_ip_address": config.target_ip_address, "show": config.show},
|
||||
]
|
||||
|
||||
|
||||
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
target_protocol: Optional[Union[str, List[str]]] = (None,)
|
||||
target_port: Optional[Union[str, List[str]]] = (None,)
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(
|
||||
cls,
|
||||
config: ConfigSchema,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"port_scan",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
target_protocol: Optional[Union[str, List[str]]] = (None,)
|
||||
target_port: Optional[Union[str, List[str]]] = (None,)
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(
|
||||
cls,
|
||||
config: ConfigSchema,
|
||||
) -> List[str]: # noqa
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"network_service_recon",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -18,6 +18,11 @@ __all__ = (
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
|
||||
"""Abstract Action for Node Service related actions.
|
||||
|
||||
Any actions which use node_name and service_name can inherit from this class.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
node_name: str
|
||||
service_name: str
|
||||
@@ -34,7 +39,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_
|
||||
"""Action which scans a service."""
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceScanAction"""
|
||||
"""Configuration Schema for NodeServiceScanAction."""
|
||||
|
||||
verb: str = "scan"
|
||||
|
||||
|
||||
@@ -19,8 +19,11 @@ class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstrac
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Abstract method. Should return the action formatted as a request which
|
||||
can be ingested by the PrimAITE simulation."""
|
||||
"""
|
||||
Abstract method for request forming.
|
||||
|
||||
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable\n"
|
||||
"from prettytable import PrettyTable\n",
|
||||
"UDP=\"UDP\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -195,7 +196,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -209,7 +210,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1783,7 +1783,7 @@
|
||||
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
|
||||
"from primaite.simulator.network.transmission.transport_layer import Port\n",
|
||||
"# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n",
|
||||
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n",
|
||||
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol[\"UDP\"], masquerade_port=Port[\"DNS\"])\n",
|
||||
"c2_beacon.establish()\n",
|
||||
"c2_beacon.show()"
|
||||
]
|
||||
@@ -1804,7 +1804,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -168,7 +168,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -182,7 +182,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -182,7 +182,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
|
||||
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port[\"HTTP\"], protocol = IPProtocol[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -249,7 +249,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -263,7 +263,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -537,7 +537,7 @@
|
||||
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
|
||||
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
|
||||
" action=ACLAction.DENY,\n",
|
||||
" protocol=IPProtocol["ICMP"],\n",
|
||||
" protocol=IPProtocol[\"ICMP\"],\n",
|
||||
" src_ip_address=\"192.168.10.22\",\n",
|
||||
" position=1\n",
|
||||
")"
|
||||
@@ -650,7 +650,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user