#2912 - Actioning review comments
This commit is contained in:
@@ -29,4 +29,5 @@ __all__ = (
|
||||
"node",
|
||||
"service",
|
||||
"session",
|
||||
"ActionManager",
|
||||
)
|
||||
|
||||
@@ -36,6 +36,4 @@ class AbstractAction(BaseModel):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractAction":
|
||||
"""Create an action component from a config dictionary."""
|
||||
if not config.get("type"):
|
||||
config.update({"type": cls.__name__})
|
||||
return cls(config=cls.ConfigSchema(**config))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import List, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
@@ -12,34 +14,48 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"):
|
||||
"""Base class for ACL actions."""
|
||||
class ACLAddRuleAbstractAction(AbstractAction, identifier="acl_add_rule_abstract_action"):
|
||||
"""Base abstract class for ACL add rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL abstract actions."""
|
||||
"""Configuration Schema base for ACL add rule abstract actions."""
|
||||
|
||||
src_ip: str
|
||||
protocol_name: str
|
||||
permission: str
|
||||
position: int
|
||||
src_ip: str
|
||||
dst_ip: str
|
||||
src_port: str
|
||||
dst_port: str
|
||||
src_wildcard: int
|
||||
dst_wildcard: int
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
|
||||
"""Base abstract class for ACL remove rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL remove rule abstract actions."""
|
||||
|
||||
src_ip: str
|
||||
protocol_name: str
|
||||
position: int
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
config: "RouterACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for RouterACLAddRuleAction."""
|
||||
|
||||
target_router: str
|
||||
permission: str
|
||||
protocol_name: str
|
||||
position: int
|
||||
src_ip: str
|
||||
src_wildcard: int
|
||||
source_port: str
|
||||
dst_ip: str
|
||||
dst_wildcard: int
|
||||
dst_port: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
@@ -62,16 +78,15 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"):
|
||||
]
|
||||
|
||||
|
||||
class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_rule"):
|
||||
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
config: "RouterACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RouterACLRemoveRuleAction."""
|
||||
|
||||
target_router: str
|
||||
position: int
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
@@ -79,31 +94,22 @@ class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_ru
|
||||
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
|
||||
|
||||
|
||||
class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_rule"):
|
||||
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
|
||||
"""Action which adds a rule to a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAbstractAction.ConfigSchema):
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLAddRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
position: int
|
||||
permission: str
|
||||
src_ip: str
|
||||
dest_ip: str
|
||||
src_port: str
|
||||
dst_port: str
|
||||
protocol_name: str
|
||||
source_wildcard_id: int
|
||||
dest_wildcard_id: int
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.protocol_name == None:
|
||||
if config.protocol_name is None:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
if config.src_ip == 0:
|
||||
return ["do_nothing"] # invalid formulation
|
||||
@@ -121,27 +127,26 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
config.src_ip,
|
||||
config.source_wildcard_id,
|
||||
config.src_wildcard,
|
||||
config.src_port,
|
||||
config.dest_ip,
|
||||
config.dest_wildcard_id,
|
||||
config.dst_ip,
|
||||
config.dst_wildcard,
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"):
|
||||
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
|
||||
"""Action which removes a rule from a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLRemoveRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
position: int
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> List[str]:
|
||||
|
||||
@@ -34,8 +34,6 @@ class NodeApplicationAbstractAction(AbstractAction, identifier="node_application
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.application_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
@@ -103,8 +101,6 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="no
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
@@ -129,8 +125,6 @@ class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="nod
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo
|
||||
from pydantic import ConfigDict, Field, field_validator, ValidationInfo
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction, ActionManager
|
||||
from primaite.interface.request import RequestFormat
|
||||
@@ -27,7 +27,6 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureRansomwareScriptAction."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_name: str
|
||||
server_ip_address: Optional[str]
|
||||
server_password: Optional[str]
|
||||
@@ -109,17 +108,7 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
|
||||
@classmethod
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
configuration = ConfigureC2BeaconAction.ConfigSchema(
|
||||
c2_server_ip_address=config.c2_server_ip_address,
|
||||
keep_alive_frequency=config.keep_alive_frequency,
|
||||
masquerade_port=config.masquerade_port,
|
||||
masquerade_protocol=config.masquerade_protocol,
|
||||
)
|
||||
|
||||
ConfigureC2BeaconAction.ConfigSchema.model_validate(configuration) # check that options adhere to schema
|
||||
|
||||
configuration = []
|
||||
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", configuration]
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,7 @@ agents:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
@@ -45,7 +44,9 @@ class ActionManager:
|
||||
def __init__(
|
||||
self,
|
||||
actions: List[Dict], # stores list of actions available to agent
|
||||
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
|
||||
act_map: Optional[
|
||||
Dict[int, Dict]
|
||||
] = None, # allows restricting set of possible actions - TODO: Refactor to be a list?
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -79,43 +80,6 @@ class ActionManager:
|
||||
# make sure all numbers between 0 and N are represented as dict keys in action map
|
||||
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
|
||||
|
||||
def _enumerate_actions(
|
||||
self,
|
||||
) -> Dict[int, Tuple[str, Dict]]:
|
||||
"""Generate a list of all the possible actions that could be taken.
|
||||
|
||||
This enumerates all actions all combinations of parameters you could choose for those actions. The output
|
||||
of this function is intended to populate the self.action_map parameter in the situation where the user provides
|
||||
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
|
||||
|
||||
The enumeration relies on the Actions' `shape` attribute.
|
||||
|
||||
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
|
||||
An example output could be:
|
||||
{0: ("do_nothing", {'dummy': 0}),
|
||||
1: ("node_os_scan", {'node_name': computer}),
|
||||
2: ("node_os_scan", {'node_name': server}),
|
||||
3: ("node_folder_scan", {'node_name:computer, folder_name:downloads}),
|
||||
... #etc...
|
||||
}
|
||||
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
|
||||
"""
|
||||
all_action_possibilities = []
|
||||
for act_name, action in self.actions.items():
|
||||
param_names = list(action.shape.keys())
|
||||
num_possibilities = list(action.shape.values())
|
||||
possibilities = [range(n) for n in num_possibilities]
|
||||
|
||||
param_combinations = list(itertools.product(*possibilities))
|
||||
all_action_possibilities.extend(
|
||||
[
|
||||
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
|
||||
for j in range(len(param_combinations))
|
||||
]
|
||||
)
|
||||
|
||||
return {i: p for i, p in enumerate(all_action_possibilities)}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""Produce action in CAOS format."""
|
||||
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
|
||||
@@ -125,8 +89,9 @@ class ActionManager:
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_obj = self.actions[action_identifier].from_config(config=action_options)
|
||||
return act_obj.form_request(config=act_obj.config)
|
||||
act_class = AbstractAction._registry[action_identifier]
|
||||
config = act_class.ConfigSchema(**action_options)
|
||||
return act_class.form_request(config=config)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
@@ -134,7 +99,7 @@ class ActionManager:
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager":
|
||||
def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": # noqa: F821
|
||||
"""
|
||||
Construct an ActionManager from a config definition.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user