#2880: fix action shape for num_ports + test

This commit is contained in:
Czar Echavez
2024-09-13 11:11:58 +01:00
parent eb24d1270b
commit 9a2fb2a084
2 changed files with 23 additions and 2 deletions

View File

@@ -877,7 +877,7 @@ class FirewallACLRemoveRuleAction(AbstractAction):
"""Action which removes a rule from a firewall port's ACL.""" """Action which removes a rule from a firewall port's ACL."""
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
"""Init method for RouterACLRemoveRuleAction. """Init method for FirewallACLRemoveRuleAction.
:param manager: Reference to the ActionManager which created this action. :param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager :type manager: ActionManager
@@ -1524,7 +1524,7 @@ class ActionManager:
"num_nics": max_nics_per_node, "num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules, "num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols), "num_protocols": len(self.protocols),
"num_ports": len(self.protocols), "num_ports": len(self.ports),
"num_ips": len(self.ip_address_list), "num_ips": len(self.ip_address_list),
"max_acl_rules": max_acl_rules, "max_acl_rules": max_acl_rules,
"max_nics_per_node": max_nics_per_node, "max_nics_per_node": max_nics_per_node,

View File

@@ -0,0 +1,21 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from tests import TEST_ASSETS_ROOT
FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml"
def test_router_acl_add_rule_action_shape(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test to check ROUTER_ADD_ACL_RULE has the expected action shape."""
game, agent = game_and_agent
# assert that the shape of the actions is correct
router_acl_add_rule_action = agent.action_manager.actions.get("ROUTER_ACL_ADDRULE")
assert router_acl_add_rule_action.shape.get("source_ip_id") == len(agent.action_manager.ip_address_list)
assert router_acl_add_rule_action.shape.get("dest_ip_id") == len(agent.action_manager.ip_address_list)
assert router_acl_add_rule_action.shape.get("source_port_id") == len(agent.action_manager.ports)
assert router_acl_add_rule_action.shape.get("dest_port_id") == len(agent.action_manager.ports)
assert router_acl_add_rule_action.shape.get("protocol_id") == len(agent.action_manager.protocols)