#2912 - Actioning review comments

This commit is contained in:
Charlie Crane
2024-12-06 15:12:31 +00:00
parent 8f610a3dd9
commit be174b6477
6 changed files with 53 additions and 101 deletions

View File

@@ -29,4 +29,5 @@ __all__ = (
"node",
"service",
"session",
"ActionManager",
)

View File

@@ -36,6 +36,4 @@ class AbstractAction(BaseModel):
@classmethod
def from_config(cls, config: Dict) -> "AbstractAction":
"""Create an action component from a config dictionary."""
if not config.get("type"):
config.update({"type": cls.__name__})
return cls(config=cls.ConfigSchema(**config))

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import List, Literal, Union
from __future__ import annotations
from typing import List
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
@@ -12,34 +14,48 @@ __all__ = (
)
class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"):
"""Base class for ACL actions."""
class ACLAddRuleAbstractAction(AbstractAction, identifier="acl_add_rule_abstract_action"):
"""Base abstract class for ACL add rule actions."""
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL abstract actions."""
"""Configuration Schema base for ACL add rule abstract actions."""
src_ip: str
protocol_name: str
permission: str
position: int
src_ip: str
dst_ip: str
src_port: str
dst_port: str
src_wildcard: int
dst_wildcard: int
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
"""Base abstract class for ACL remove rule actions."""
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL remove rule abstract actions."""
src_ip: str
protocol_name: str
position: int
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
config: "RouterACLAddRuleAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
"""Configuration Schema for RouterACLAddRuleAction."""
target_router: str
permission: str
protocol_name: str
position: int
src_ip: str
src_wildcard: int
source_port: str
dst_ip: str
dst_wildcard: int
dst_port: str
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
@@ -62,16 +78,15 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
]
class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_rule"):
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
"""Action which removes a rule from a router's ACL."""
config: "RouterACLRemoveRuleAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
"""Configuration schema for RouterACLRemoveRuleAction."""
target_router: str
position: int
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
@@ -79,31 +94,22 @@ class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_ru
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_rule"):
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
"""Action which adds a rule to a firewall port's ACL."""
config: "FirewallACLAddRuleAction.ConfigSchema"
class ConfigSchema(ACLAbstractAction.ConfigSchema):
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLAddRuleAction."""
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
position: int
permission: str
src_ip: str
dest_ip: str
src_port: str
dst_port: str
protocol_name: str
source_wildcard_id: int
dest_wildcard_id: int
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.protocol_name == None:
if config.protocol_name is None:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if config.src_ip == 0:
return ["do_nothing"] # invalid formulation
@@ -121,27 +127,26 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
config.permission,
config.protocol_name,
config.src_ip,
config.source_wildcard_id,
config.src_wildcard,
config.src_port,
config.dest_ip,
config.dest_wildcard_id,
config.dst_ip,
config.dst_wildcard,
config.dst_port,
config.position,
]
class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"):
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
"""Action which removes a rule from a firewall port's ACL."""
config: "FirewallACLRemoveRuleAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLRemoveRuleAction."""
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
position: int
@classmethod
def form_request(cls, config: ConfigSchema) -> List[str]:

View File

@@ -34,8 +34,6 @@ class NodeApplicationAbstractAction(AbstractAction, identifier="node_application
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.application_name is None:
return ["do_nothing"]
return [
"network",
"node",
@@ -103,8 +101,6 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="no
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None:
return ["do_nothing"]
return [
"network",
"node",
@@ -129,8 +125,6 @@ class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="nod
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None:
return ["do_nothing"]
return [
"network",
"node",

View File

@@ -1,8 +1,8 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
from pydantic import ConfigDict, Field, field_validator, ValidationInfo
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
from primaite.interface.request import RequestFormat
@@ -27,7 +27,6 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for ConfigureRansomwareScriptAction."""
model_config = ConfigDict(extra="forbid")
node_name: str
server_ip_address: Optional[str]
server_password: Optional[str]
@@ -109,17 +108,7 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
@classmethod
def form_request(self, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
configuration = ConfigureC2BeaconAction.ConfigSchema(
c2_server_ip_address=config.c2_server_ip_address,
keep_alive_frequency=config.keep_alive_frequency,
masquerade_port=config.masquerade_port,
masquerade_protocol=config.masquerade_protocol,
)
ConfigureC2BeaconAction.ConfigSchema.model_validate(configuration) # check that options adhere to schema
configuration = []
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", configuration]

View File

@@ -13,8 +13,7 @@ agents:
from __future__ import annotations
import itertools
from typing import Dict, List, Literal, Optional, Tuple
from typing import Dict, List, Optional, Tuple
from gymnasium import spaces
@@ -45,7 +44,9 @@ class ActionManager:
def __init__(
self,
actions: List[Dict], # stores list of actions available to agent
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
act_map: Optional[
Dict[int, Dict]
] = None, # allows restricting set of possible actions - TODO: Refactor to be a list?
*args,
**kwargs,
) -> None:
@@ -79,43 +80,6 @@ class ActionManager:
# make sure all numbers between 0 and N are represented as dict keys in action map
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
def _enumerate_actions(
self,
) -> Dict[int, Tuple[str, Dict]]:
"""Generate a list of all the possible actions that could be taken.
This enumerates all actions all combinations of parameters you could choose for those actions. The output
of this function is intended to populate the self.action_map parameter in the situation where the user provides
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
The enumeration relies on the Actions' `shape` attribute.
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
An example output could be:
{0: ("do_nothing", {'dummy': 0}),
1: ("node_os_scan", {'node_name': computer}),
2: ("node_os_scan", {'node_name': server}),
3: ("node_folder_scan", {'node_name:computer, folder_name:downloads}),
... #etc...
}
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
"""
all_action_possibilities = []
for act_name, action in self.actions.items():
param_names = list(action.shape.keys())
num_possibilities = list(action.shape.values())
possibilities = [range(n) for n in num_possibilities]
param_combinations = list(itertools.product(*possibilities))
all_action_possibilities.extend(
[
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
for j in range(len(param_combinations))
]
)
return {i: p for i, p in enumerate(all_action_possibilities)}
def get_action(self, action: int) -> Tuple[str, Dict]:
"""Produce action in CAOS format."""
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
@@ -125,8 +89,9 @@ class ActionManager:
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_obj = self.actions[action_identifier].from_config(config=action_options)
return act_obj.form_request(config=act_obj.config)
act_class = AbstractAction._registry[action_identifier]
config = act_class.ConfigSchema(**action_options)
return act_class.form_request(config=config)
@property
def space(self) -> spaces.Space:
@@ -134,7 +99,7 @@ class ActionManager:
return spaces.Discrete(len(self.action_map))
@classmethod
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": # noqa: F821
"""
Construct an ActionManager from a config definition.