From a8fbb002e4299d6a34f3801a56b55b61f7fc674c Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 9 Dec 2024 09:54:35 +0000 Subject: [PATCH] #2912 - Updates following review, ACL rules now have validation for ConfigSchema fields --- src/primaite/game/agent/actions/acl.py | 42 ++++++++++++++++++----- src/primaite/game/agent/actions/config.py | 3 +- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index fb18d025..f129a82f 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -3,8 +3,12 @@ from __future__ import annotations from typing import List +from pydantic import field_validator + from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat +from primaite.utils.validation.ip_protocol import is_valid_protocol, protocol_validator +from primaite.utils.validation.port import is_valid_port __all__ = ( "RouterACLAddRuleAction", @@ -33,6 +37,35 @@ class ACLAddRuleAbstractAction(AbstractAction, identifier="acl_add_rule_abstract src_wildcard: int dst_wildcard: int + @field_validator( + src_port, + dst_port, + mode="before", + ) + @classmethod + def valid_port(cls, v: str) -> int: + """Check that inputs are valid.""" + return is_valid_port(v) + + @field_validator( + src_ip, + dst_ip, + mode="before", + ) + @classmethod + def valid_ip(cls, v: str) -> str: + """Check that a valid IP has been provided for src and dst.""" + return is_valid_protocol(v) + + @field_validator( + protocol_name, + mode="before", + ) + @classmethod + def is_valid_protocol(cls, v: str) -> bool: + """Check that we are using a valid protocol.""" + return protocol_validator(v) + class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"): """Base abstract class for ACL remove rule actions.""" @@ -70,7 +103,7 @@ class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_ad config.protocol_name, config.src_ip, config.src_wildcard, - config.source_port, + config.src_port, config.dst_ip, config.dst_wildcard, config.dst_port, @@ -109,13 +142,6 @@ class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_ac @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.protocol_name is None: - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - if config.src_ip == 0: - return ["do_nothing"] # invalid formulation - if config.src_port == 0: - return ["do_nothing"] # invalid configuration. - return [ "network", "node", diff --git a/src/primaite/game/agent/actions/config.py b/src/primaite/game/agent/actions/config.py index 319cd212..050e9b94 100644 --- a/src/primaite/game/agent/actions/config.py +++ b/src/primaite/game/agent/actions/config.py @@ -108,8 +108,7 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): @classmethod def form_request(self, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" - configuration = [] - return ["network", "node", config.node_name, "application", "C2Beacon", "configure", configuration] + return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config] class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):