From 1d2705eb1b95bb369926bc5858feef7a56cf4abe Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 15 Aug 2024 20:16:11 +0100 Subject: [PATCH] #2769 - Add user login observations --- .../observations/firewall_observation.py | 20 +- .../agent/observations/host_observations.py | 21 + .../agent/observations/node_observations.py | 8 + .../agent/observations/router_observation.py | 17 +- .../simulator/network/hardware/base.py | 3 +- tests/assets/configs/data_manipulation.yaml | 942 ++++++++++++++++++ .../observations/test_user_observations.py | 89 ++ 7 files changed, 1096 insertions(+), 4 deletions(-) create mode 100644 tests/assets/configs/data_manipulation.yaml create mode 100644 tests/integration_tests/game_layer/observations/test_user_observations.py diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 4f1a9d90..42ceaff0 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -10,6 +10,7 @@ from primaite import getLogger from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list: List[int], protocol_list: List[str], num_rules: int, + include_users: bool, ) -> None: """ Initialise a firewall observation instance. @@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :type protocol_list: List[str] :param num_rules: Number of rules configured in the firewall. :type num_rules: int + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" self.ports: List[PortObservation] = [ PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] @@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. :rtype: ObsType """ + firewall_state = access_from_nested_dict(state, self.where) + if firewall_state is NOT_PRESENT_IN_STATE: + return self.default_observation obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, "ACL": { @@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + sess = firewall_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 7053d019..4419ccc7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -52,6 +52,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -69,6 +71,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + include_users: bool, ) -> None: """ Initialise a host observation instance. @@ -103,10 +106,15 @@ class HostObservation(AbstractObservation, identifier="HOST"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.include_num_access = include_num_access + self.include_users = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services @@ -165,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: self.default_observation["num_file_creations"] = 0 self.default_observation["num_file_deletions"] = 0 + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -192,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -216,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: shape["num_file_creations"] = spaces.Discrete(4) shape["num_file_deletions"] = spaces.Discrete(4) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod @@ -273,4 +293,5 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index c68531f8..e263cadb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -46,6 +46,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Flag to include the number of accesses.""" file_system_requires_scan: bool = True """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + include_users: Optional[bool] = True + """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" ip_list: Optional[List[str]] = None @@ -191,6 +193,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.include_users is None: + host_config.include_users = config.include_users for router_config in config.routers: if router_config.num_ports is None: @@ -205,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.protocol_list = config.protocol_list if router_config.num_rules is None: router_config.num_rules = config.num_rules + if router_config.include_users is None: + router_config.include_users = config.include_users for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -217,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.protocol_list = config.protocol_list if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules + if firewall_config.include_users is None: + firewall_config.include_users = config.include_users hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index f1d4ec8e..d064936a 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports: List[PortObservation], num_ports: int, acl: ACLObservation, + include_users: bool, ) -> None: """ Initialise a router observation instance. @@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :type num_ports: int :param acl: ACL observation representing the access control list of the router. :type acl: ACLObservation + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.ports: List[PortObservation] = ports self.acl: ACLObservation = acl self.num_ports: int = num_ports - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" while len(self.ports) < num_ports: self.ports.append(PortObservation(where=None)) while len(self.ports) > num_ports: @@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): obs["ACL"] = self.acl.observe(state) if self.ports: obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 68b45c2e..b0c48e7d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1265,7 +1265,8 @@ class UserSessionManager(Service): :return: A dictionary representing the current state. """ state = super().describe_state() - state["active_remote_logins"] = len(self.remote_sessions) + state["current_local_user"] = None if not self.local_session else self.local_session.user.username + state["active_remote_sessions"] = list(self.remote_sessions.keys()) return state @property diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml new file mode 100644 index 00000000..97442903 --- /dev/null +++ b/tests/assets/configs/data_manipulation.yaml @@ -0,0 +1,942 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: switch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py new file mode 100644 index 00000000..ca5e2543 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -0,0 +1,89 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +@pytest.fixture +def env_with_ssh() -> PrimaiteGymEnv: + """Build data manipulation environment with SSH port open on router.""" + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.agent.flatten_obs = False + router: Router = env.game.simulation.network.get_node_by_hostname("router_1") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + return env + + +def extract_login_numbers_from_obs(obs): + """Traverse the observation dictionary and return number of user sessions for all nodes.""" + login_nums = {} + for node_name, node_obs in obs["NODES"].items(): + login_nums[node_name] = node_obs.get("users") + return login_nums + + +class TestUserObservations: + """Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins.""" + + def test_no_sessions_at_episode_start(self, env_with_ssh): + """Test that all of the login observations start at 0 before any logins occur.""" + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_single_login(self, env_with_ssh: PrimaiteGymEnv): + """Test that performing a remote login increases the remote_sessions observation by 1.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + assert db_srv_logins_obs["local_login"] == 0 + assert db_srv_logins_obs["remote_sessions"] == 1 + for o in logins_obs.values(): # the remaining obs after popping HOST2 + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_logout(self, env_with_ssh: PrimaiteGymEnv): + """Test that remote_sessions observation correctly decreases upon logout.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv): + """Log in from 5 remote places and check that only a max of 3 is shown in the observation.""" + MAX_OBSERVABLE_SESSIONS = 3 + # Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_session_manager.remote_session_timeout_steps = 20 + db_srv.user_session_manager.max_remote_sessions = 5 + node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller") + + for i, node_name in enumerate(node_names): + node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name) + node.terminal._send_remote_login("admin", "admin", "192.168.1.14") + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1) + assert len(db_srv.user_session_manager.remote_sessions) == i + 1