#2912 - Updates so that all tests within test_actions.py pass
This commit is contained in:
@@ -175,10 +175,6 @@ class NodeApplicationInstallAction(AbstractAction):
|
||||
class ConfigureDatabaseClientAction(AbstractAction):
|
||||
"""Action which sets config parameters for a database client on a node."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
class _Opts(BaseModel):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
@@ -204,7 +200,6 @@ class ConfigureRansomwareScriptAction(AbstractAction):
|
||||
class _Opts(BaseModel, AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this option."""
|
||||
|
||||
node_name: str
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
@@ -43,6 +43,5 @@ class AbstractAction(BaseModel):
|
||||
def from_config(cls, config: Dict) -> "AbstractAction":
|
||||
"""Create an action component from a config dictionary."""
|
||||
for attribute, value in config.items():
|
||||
if not hasattr(cls.ConfigSchema, attribute):
|
||||
setattr(cls.ConfigSchema, attribute, value)
|
||||
setattr(cls.ConfigSchema, attribute, value)
|
||||
return cls
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import List, Literal
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
@@ -92,6 +92,12 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
num_protocols: int
|
||||
num_permissions: int = 3
|
||||
permission: str
|
||||
target_firewall_nodename: str
|
||||
src_ip: str
|
||||
dst_ip: str
|
||||
dst_wildcard: str
|
||||
src_port: Union[int| None]
|
||||
dst_port: Union[int | None]
|
||||
|
||||
class ConfigSchema(ACLAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLAddRuleAction."""
|
||||
@@ -102,29 +108,22 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
num_protocols: int
|
||||
num_permissions: int = 3
|
||||
permission: str
|
||||
target_firewall_nodename: str
|
||||
src_ip: str
|
||||
dst_ip: str
|
||||
dst_wildcard: str
|
||||
src_port: Union[int| None]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.protocol_id == 0:
|
||||
if config.protocol_name == None:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
if config.source_ip_id == 0:
|
||||
if config.src_ip == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif config.source_ip_id == 1:
|
||||
src_ip = "ALL"
|
||||
else:
|
||||
# src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
pass
|
||||
if config.src_port == 0:
|
||||
return ["do_nothing"] # invalid configuration.
|
||||
|
||||
if config.source_port_id == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
elif config.source_port_id == 1:
|
||||
src_port = "ALL"
|
||||
else:
|
||||
# src_port = self.manager.get_port_by_idx(source_port_id - 2)
|
||||
# subtract 2 to account for UNUSED=0, and ALL=1
|
||||
pass
|
||||
|
||||
return [
|
||||
"network",
|
||||
@@ -136,9 +135,9 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
str(src_ip),
|
||||
config.src_ip,
|
||||
config.src_wildcard,
|
||||
src_port,
|
||||
config.src_port,
|
||||
config.dst_ip,
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
|
||||
@@ -37,7 +37,7 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema
|
||||
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config]
|
||||
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config.model_config]
|
||||
|
||||
|
||||
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
|
||||
|
||||
@@ -112,14 +112,14 @@ agents:
|
||||
firewall_port_name: internal
|
||||
firewall_port_direction: inbound
|
||||
position: 1
|
||||
permission: 1
|
||||
source_ip_id: 2 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: PERMIT
|
||||
src_ip: 192.168.0.10
|
||||
dst_ip: ALL
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: ALL
|
||||
src_wildcard: 0
|
||||
dst_wildcard: 0
|
||||
2:
|
||||
action: firewall_acl_remove_rule
|
||||
options:
|
||||
@@ -134,12 +134,12 @@ agents:
|
||||
firewall_port_name: internal
|
||||
firewall_port_direction: outbound
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 2 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 2
|
||||
dest_port_id: 3
|
||||
protocol_id: 2
|
||||
permission: DENY
|
||||
src_ip: 192.168.0.10 # client 1
|
||||
dest_ip: ALL # ALL
|
||||
src_port: ARP
|
||||
dst_port: DNS
|
||||
protocol_name: ICMP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
4:
|
||||
@@ -156,12 +156,12 @@ agents:
|
||||
firewall_port_name: dmz
|
||||
firewall_port_direction: inbound
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 3 # dmz_server
|
||||
dest_ip_id: 2 # client_1
|
||||
source_port_id: 4
|
||||
dest_port_id: 4
|
||||
protocol_id: 4
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.10 # dmz_server
|
||||
dst_ip: 192.168.0.10 # client_1
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
protocol_name: UDP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
6:
|
||||
@@ -178,12 +178,12 @@ agents:
|
||||
firewall_port_name: dmz
|
||||
firewall_port_direction: outbound
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 3 # dmz_server
|
||||
dest_ip_id: 2 # client_1
|
||||
source_port_id: 4
|
||||
dest_port_id: 4
|
||||
protocol_id: 3
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.10 # dmz_server
|
||||
dst_ip: 192.168.0.10 # client_1
|
||||
src_port: HTTP
|
||||
dst_port: HTTP
|
||||
protocol_name: TCP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
8:
|
||||
@@ -200,12 +200,12 @@ agents:
|
||||
firewall_port_name: external
|
||||
firewall_port_direction: inbound
|
||||
position: 10
|
||||
permission: 2
|
||||
source_ip_id: 4 # external_computer
|
||||
dest_ip_id: 3 # dmz
|
||||
source_port_id: 5
|
||||
dest_port_id: 5
|
||||
protocol_id: 2
|
||||
permission: DENY
|
||||
src_ip: 192.168.20.10 # external_computer
|
||||
dst_ip: 192.168.10.10 # dmz
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
protocol_name: ICMP
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
10:
|
||||
@@ -222,12 +222,12 @@ agents:
|
||||
firewall_port_name: external
|
||||
firewall_port_direction: outbound
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 4 # external_computer
|
||||
dest_ip_id: 2 # client_1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
permission: DENY
|
||||
src_ip: 192.168.20.10 # external_computer
|
||||
dst_ip: 192.168.0.10 # client_1
|
||||
src_port: NONE
|
||||
dst_port: NONE
|
||||
protocol_name: none
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
12:
|
||||
|
||||
@@ -163,8 +163,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
print(agent.most_recent_action)
|
||||
game.step()
|
||||
|
||||
print(agent.most_recent_action)
|
||||
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
|
||||
print(router.acl.show())
|
||||
assert router.acl.num_rules == 6
|
||||
@@ -653,9 +654,9 @@ def test_firewall_acl_add_remove_rule_integration():
|
||||
assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
|
||||
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
|
||||
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
|
||||
assert firewall.external_outbound_acl.acl[1].dst_port is None
|
||||
assert firewall.external_outbound_acl.acl[1].src_port is None
|
||||
assert firewall.external_outbound_acl.acl[1].protocol is None
|
||||
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"]
|
||||
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"]
|
||||
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
|
||||
|
||||
env.step(12) # Remove ACL rule from External Outbound
|
||||
assert firewall.external_outbound_acl.num_rules == 1
|
||||
|
||||
Reference in New Issue
Block a user