#2912 - Updates so that all tests within test_actions.py pass

This commit is contained in:
Charlie Crane
2024-11-13 15:32:48 +00:00
parent ed020f005f
commit 95fbe45137
6 changed files with 63 additions and 69 deletions

View File

@@ -175,10 +175,6 @@ class NodeApplicationInstallAction(AbstractAction):
class ConfigureDatabaseClientAction(AbstractAction):
"""Action which sets config parameters for a database client on a node."""
model_config = ConfigDict(extra="forbid")
server_ip_address: Optional[str] = None
server_password: Optional[str] = None
class _Opts(BaseModel):
"""Schema for options that can be passed to this action."""
@@ -204,7 +200,6 @@ class ConfigureRansomwareScriptAction(AbstractAction):
class _Opts(BaseModel, AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this option."""
node_name: str
model_config = ConfigDict(extra="forbid")
server_ip_address: Optional[str] = None
server_password: Optional[str] = None

View File

@@ -43,6 +43,5 @@ class AbstractAction(BaseModel):
def from_config(cls, config: Dict) -> "AbstractAction":
"""Create an action component from a config dictionary."""
for attribute, value in config.items():
if not hasattr(cls.ConfigSchema, attribute):
setattr(cls.ConfigSchema, attribute, value)
setattr(cls.ConfigSchema, attribute, value)
return cls

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import List, Literal
from typing import List, Literal, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
@@ -92,6 +92,12 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
num_protocols: int
num_permissions: int = 3
permission: str
target_firewall_nodename: str
src_ip: str
dst_ip: str
dst_wildcard: str
src_port: Union[int| None]
dst_port: Union[int | None]
class ConfigSchema(ACLAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLAddRuleAction."""
@@ -102,29 +108,22 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
num_protocols: int
num_permissions: int = 3
permission: str
target_firewall_nodename: str
src_ip: str
dst_ip: str
dst_wildcard: str
src_port: Union[int| None]
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.protocol_id == 0:
if config.protocol_name == None:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if config.source_ip_id == 0:
if config.src_ip == 0:
return ["do_nothing"] # invalid formulation
elif config.source_ip_id == 1:
src_ip = "ALL"
else:
# src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
pass
if config.src_port == 0:
return ["do_nothing"] # invalid configuration.
if config.source_port_id == 0:
return ["do_nothing"] # invalid formulation
elif config.source_port_id == 1:
src_port = "ALL"
else:
# src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
pass
return [
"network",
@@ -136,9 +135,9 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
"add_rule",
config.permission,
config.protocol_name,
str(src_ip),
config.src_ip,
config.src_wildcard,
src_port,
config.src_port,
config.dst_ip,
config.dst_wildcard,
config.dst_port,

View File

@@ -37,7 +37,7 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans
if config.node_name is None:
return ["do_nothing"]
ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config]
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config.model_config]
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):

View File

@@ -112,14 +112,14 @@ agents:
firewall_port_name: internal
firewall_port_direction: inbound
position: 1
permission: 1
source_ip_id: 2 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
permission: PERMIT
src_ip: 192.168.0.10
dst_ip: ALL
src_port: ALL
dst_port: ALL
protocol_name: ALL
src_wildcard: 0
dst_wildcard: 0
2:
action: firewall_acl_remove_rule
options:
@@ -134,12 +134,12 @@ agents:
firewall_port_name: internal
firewall_port_direction: outbound
position: 1
permission: 2
source_ip_id: 2 # client 1
dest_ip_id: 1 # ALL
source_port_id: 2
dest_port_id: 3
protocol_id: 2
permission: DENY
src_ip: 192.168.0.10 # client 1
dest_ip: ALL # ALL
src_port: ARP
dst_port: DNS
protocol_name: ICMP
source_wildcard_id: 0
dest_wildcard_id: 0
4:
@@ -156,12 +156,12 @@ agents:
firewall_port_name: dmz
firewall_port_direction: inbound
position: 1
permission: 2
source_ip_id: 3 # dmz_server
dest_ip_id: 2 # client_1
source_port_id: 4
dest_port_id: 4
protocol_id: 4
permission: DENY
src_ip: 192.168.10.10 # dmz_server
dst_ip: 192.168.0.10 # client_1
src_port: HTTP
dst_port: HTTP
protocol_name: UDP
source_wildcard_id: 0
dest_wildcard_id: 0
6:
@@ -178,12 +178,12 @@ agents:
firewall_port_name: dmz
firewall_port_direction: outbound
position: 2
permission: 2
source_ip_id: 3 # dmz_server
dest_ip_id: 2 # client_1
source_port_id: 4
dest_port_id: 4
protocol_id: 3
permission: DENY
src_ip: 192.168.10.10 # dmz_server
dst_ip: 192.168.0.10 # client_1
src_port: HTTP
dst_port: HTTP
protocol_name: TCP
source_wildcard_id: 0
dest_wildcard_id: 0
8:
@@ -200,12 +200,12 @@ agents:
firewall_port_name: external
firewall_port_direction: inbound
position: 10
permission: 2
source_ip_id: 4 # external_computer
dest_ip_id: 3 # dmz
source_port_id: 5
dest_port_id: 5
protocol_id: 2
permission: DENY
src_ip: 192.168.20.10 # external_computer
dst_ip: 192.168.10.10 # dmz
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
protocol_name: ICMP
source_wildcard_id: 0
dest_wildcard_id: 0
10:
@@ -222,12 +222,12 @@ agents:
firewall_port_name: external
firewall_port_direction: outbound
position: 1
permission: 2
source_ip_id: 4 # external_computer
dest_ip_id: 2 # client_1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
permission: DENY
src_ip: 192.168.20.10 # external_computer
dst_ip: 192.168.0.10 # client_1
src_port: NONE
dst_port: NONE
protocol_name: none
source_wildcard_id: 0
dest_wildcard_id: 0
12:

View File

@@ -163,8 +163,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
},
)
agent.store_action(action)
print(agent.most_recent_action)
game.step()
print(agent.most_recent_action)
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
print(router.acl.show())
assert router.acl.num_rules == 6
@@ -653,9 +654,9 @@ def test_firewall_acl_add_remove_rule_integration():
assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
assert firewall.external_outbound_acl.acl[1].dst_port is None
assert firewall.external_outbound_acl.acl[1].src_port is None
assert firewall.external_outbound_acl.acl[1].protocol is None
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
env.step(12) # Remove ACL rule from External Outbound
assert firewall.external_outbound_acl.num_rules == 1