diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aba9e6b..850b216d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. +- Added actions to establish SSH connections, send commands remotely and terminate SSH connections. +- Added actions to change users' passwords. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. +- Added reward calculation details to AgentHistoryItem. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. - It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` +- Node observations can now be configured to show the number of active local and remote logins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 7263cfc1..2a0c5351 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1071,6 +1071,83 @@ class NodeNetworkServiceReconAction(AbstractAction): ] +class NodeAccountsChangePasswordAction(AbstractAction): + """Action which changes the password for a user.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserManager", + "change_password", + username, + current_password, + new_password, + ] + + +class NodeSessionsRemoteLoginAction(AbstractAction): + """Action which performs a remote session login.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "ssh_to_remote", + username, + password, + remote_ip, + ] + + +class NodeSessionsRemoteLogoutAction(AbstractAction): + """Action which performs a remote session logout.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: str, remote_ip: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip] + + +class NodeSendRemoteCommandAction(AbstractAction): + """Action which sends a terminal command to a remote node via SSH.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "send_remote_command", + remote_ip, + {"command": command}, + ] + + class ActionManager: """Class which manages the action space for an agent.""" @@ -1122,6 +1199,10 @@ class ActionManager: "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, "CONFIGURE_DOSBOT": ConfigureDoSBotAction, + "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, + "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, + "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, + "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index f57dc191..14b97821 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel): reward: Optional[float] = None + reward_info: Dict[str, Any] = {} + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 4f1a9d90..42ceaff0 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -10,6 +10,7 @@ from primaite import getLogger from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list: List[int], protocol_list: List[str], num_rules: int, + include_users: bool, ) -> None: """ Initialise a firewall observation instance. @@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :type protocol_list: List[str] :param num_rules: Number of rules configured in the firewall. :type num_rules: int + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" self.ports: List[PortObservation] = [ PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] @@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. :rtype: ObsType """ + firewall_state = access_from_nested_dict(state, self.where) + if firewall_state is NOT_PRESENT_IN_STATE: + return self.default_observation obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, "ACL": { @@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + sess = firewall_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 7053d019..4419ccc7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -52,6 +52,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -69,6 +71,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + include_users: bool, ) -> None: """ Initialise a host observation instance. @@ -103,10 +106,15 @@ class HostObservation(AbstractObservation, identifier="HOST"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.include_num_access = include_num_access + self.include_users = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services @@ -165,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: self.default_observation["num_file_creations"] = 0 self.default_observation["num_file_deletions"] = 0 + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -192,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -216,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: shape["num_file_creations"] = spaces.Discrete(4) shape["num_file_deletions"] = spaces.Discrete(4) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod @@ -273,4 +293,5 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index c68531f8..e263cadb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -46,6 +46,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Flag to include the number of accesses.""" file_system_requires_scan: bool = True """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + include_users: Optional[bool] = True + """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" ip_list: Optional[List[str]] = None @@ -191,6 +193,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.include_users is None: + host_config.include_users = config.include_users for router_config in config.routers: if router_config.num_ports is None: @@ -205,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.protocol_list = config.protocol_list if router_config.num_rules is None: router_config.num_rules = config.num_rules + if router_config.include_users is None: + router_config.include_users = config.include_users for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -217,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.protocol_list = config.protocol_list if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules + if firewall_config.include_users is None: + firewall_config.include_users = config.include_users hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index f1d4ec8e..d064936a 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports: List[PortObservation], num_ports: int, acl: ACLObservation, + include_users: bool, ) -> None: """ Initialise a router observation instance. @@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :type num_ports: int :param acl: ACL observation representing the access control list of the router. :type acl: ACLObservation + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.ports: List[PortObservation] = ports self.acl: ACLObservation = acl self.num_ports: int = num_ports - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" while len(self.ports) < num_ports: self.ports.append(PortObservation(where=None)) while len(self.ports) > num_ports: @@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): obs["ACL"] = self.acl.observe(state) if self.ports: obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 321df098..73bc7b11 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,7 +47,15 @@ class AbstractReward: @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -67,7 +75,15 @@ class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward): def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. - :param state: The current state of the simulation. + :param state: Current simulation state :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: @@ -322,6 +342,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last successful request returned was able to connect to the database server, because there has been an unsuccessful request since. + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ request_attempted = last_action_response.request == [ "network", @@ -333,10 +359,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): ] if request_attempted: # if agent makes request, always recalculate fresh value + last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status} self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 elif not self.sticky: # if no new request and not sticky, set reward to 0 + last_action_response.reward_info = {"connection_attempt_status": "n/a"} self.reward = 0.0 else: # if no new request and sticky, reuse reward value from last step + last_action_response.reward_info = {"connection_attempt_status": "n/a"} pass return self.reward @@ -384,7 +413,15 @@ class SharedReward(AbstractReward): """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Simply access the other agent's reward and return it.""" + """Simply access the other agent's reward and return it. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return self.callback(self.agent_name) @classmethod @@ -417,7 +454,15 @@ class ActionPenalty(AbstractReward): self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the penalty to be applied.""" + """Calculate the penalty to be applied. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ if last_action_response.action == "DONOTHING": return self.do_nothing_penalty else: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 1441c93b..b0c48e7d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -990,6 +990,7 @@ class UserManager(Service): if user and user.password == current_password: user.password = new_password self.sys_log.info(f"{self.name}: Password changed for {username}") + self._user_session_manager._logout_user(user=user) return True self.sys_log.info(f"{self.name}: Password change failed for {username}") return False @@ -1027,6 +1028,10 @@ class UserManager(Service): self.sys_log.info(f"{self.name}: Failed to enable user: {username}") return False + @property + def _user_session_manager(self) -> "UserSessionManager": + return self.software_manager.software["UserSessionManager"] # noqa + class UserSession(SimComponent): """ @@ -1260,7 +1265,8 @@ class UserSessionManager(Service): :return: A dictionary representing the current state. """ state = super().describe_state() - state["active_remote_logins"] = len(self.remote_sessions) + state["current_local_user"] = None if not self.local_session else self.local_session.user.username + state["active_remote_sessions"] = list(self.remote_sessions.keys()) return state @property @@ -1435,6 +1441,19 @@ class UserSessionManager(Service): """ return self._logout(local=False, remote_session_id=remote_session_id) + def _logout_user(self, user: Union[str, User]) -> bool: + """End a user session by username or user object.""" + if isinstance(user, str): + user = self._user_manager.users[user] # grab user object from username + for sess_id, session in self.remote_sessions.items(): + if session.user is user: + self._logout(local=False, remote_session_id=sess_id) + return True + if self.local_user_logged_in and self.local_session.user is user: + self.local_logout() + return True + return False + @property def local_user_logged_in(self) -> bool: """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 876b1694..406facd1 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor +# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. @@ -92,7 +94,7 @@ class LocalTerminalConnection(TerminalClientConnection): if not self.is_active: self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") return None - return self.parent_terminal.execute(command, connection_id=self.connection_uuid) + return self.parent_terminal.execute(command) class RemoteTerminalConnection(TerminalClientConnection): @@ -162,22 +164,6 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" rm = super()._init_request_manager() - rm.add_request( - "send", - request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), - ) - - def _login(request: RequestFormat, context: Dict) -> RequestResponse: - login = self._process_local_login(username=request[0], password=request[1]) - if login: - return RequestResponse( - status="success", - data={ - "ip_address": login.ip_address, - }, - ) - else: - return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) @@ -191,10 +177,34 @@ class Terminal(Service): else: return RequestResponse(status="failure", data={}) + rm.add_request( + "ssh_to_remote", + request_type=RequestType(func=_remote_login), + ) + + def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse: + """Logoff from remote connection.""" + ip_address = IPv4Address(request[0]) + remote_connection = self._get_connection_from_ip(ip_address=ip_address) + if remote_connection: + outcome = self._disconnect(remote_connection.connection_uuid) + if outcome: + return RequestResponse( + status="success", + data={}, + ) + else: + return RequestResponse( + status="failure", + data={"reason": "No remote connection held."}, + ) + + rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff)) + def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" - command: str = request[0] - ip_address: IPv4Address = IPv4Address(request[1]) + ip_address: IPv4Address = IPv4Address(request[0]) + command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: outcome = remote_connection.execute(command) @@ -209,30 +219,11 @@ class Terminal(Service): data={}, ) - def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: - """Logoff from connection.""" - connection_uuid = request[0] - self.parent.user_session_manager.local_logout(connection_uuid) - self._disconnect(connection_uuid) - return RequestResponse(status="success", data={}) - rm.add_request( - "Login", - request_type=RequestType(func=_login), - ) - - rm.add_request( - "Remote Login", - request_type=RequestType(func=_remote_login), - ) - - rm.add_request( - "Execute", + "send_remote_command", request_type=RequestType(func=remote_execute_request), ) - rm.add_request("Logoff", request_type=RequestType(func=_logoff)) - return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: @@ -280,13 +271,9 @@ class Terminal(Service): if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None - connection_request_id = str(uuid4()) - self._client_connection_requests[connection_request_id] = None if ip_address: # Assuming that if IP is passed we are connecting to remote - return self._send_remote_login( - username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id - ) + return self._send_remote_login(username=username, password=password, ip_address=ip_address) else: return self._process_local_login(username=username, password=password) @@ -313,6 +300,9 @@ class Terminal(Service): def _check_client_connection(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" + if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id): + self._disconnect(connection_id) + return False return connection_id in self._connections def _send_remote_login( @@ -320,32 +310,24 @@ class Terminal(Service): username: str, password: str, ip_address: IPv4Address, - connection_request_id: str, + connection_request_id: Optional[str] = None, is_reattempt: bool = False, ) -> Optional[RemoteTerminalConnection]: """Send a remote login attempt and connect to Node. :param: username: Username used to connect to the remote node. :type: username: str - :param: password: Password used to connect to the remote node :type: password: str - :param: ip_address: Target Node IP address for login attempt. :type: ip_address: IPv4Address - - :param: connection_request_id: Connection Request ID - :type: connection_request_id: str - + :param: connection_request_id: Connection Request ID, if not provided, a new one is generated + :type: connection_request_id: Optional[str] :param: is_reattempt: True if the request has been reattempted. Default False. :type: is_reattempt: Optional[bool] - :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. - """ - self.sys_log.info( - f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" - ) + connection_request_id = connection_request_id or str(uuid4()) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -360,6 +342,9 @@ class Terminal(Service): self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") return None + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password) diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml new file mode 100644 index 00000000..97442903 --- /dev/null +++ b/tests/assets/configs/data_manipulation.yaml @@ -0,0 +1,942 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: switch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/conftest.py b/tests/conftest.py index 2d605c94..abc851c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -458,6 +458,10 @@ def game_and_agent(): {"type": "HOST_NIC_DISABLE"}, {"type": "NETWORK_PORT_ENABLE"}, {"type": "NETWORK_PORT_DISABLE"}, + {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, + {"type": "SSH_TO_REMOTE"}, + {"type": "SESSIONS_REMOTE_LOGOFF"}, + {"type": "NODE_SEND_REMOTE_COMMAND"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py new file mode 100644 index 00000000..d011c1e8 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -0,0 +1,166 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import UserManager +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + + return (game, agent) + + +def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if not connection_established: + pytest.fail("Remote SSH connection could not be established") + + +def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "wrong_password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "failure" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if connection_established: + pytest.fail("Remote SSH connection was established despite wrong password") + + +def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_um: UserManager = server_1.software_manager.software["UserManager"] + server_1_um.add_user("user123", "password", is_admin=True) + + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + assert server_1_um.users["user123"].password == "different_password" + + +def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + # Log in remotely + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + + # Change password + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + # Assert that the user cannot execute an action + action = ( + "NODE_SEND_REMOTE_COMMAND", + { + "node_id": 0, + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert server_1.file_system.get_folder("folder123") is None + assert server_1.file_system.get_file("folder123", "doggo.pdf") is None diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 99417e33..34a37f5e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -33,6 +33,7 @@ def test_firewall_observation(): wildcard_list=["0.0.0.255", "0.0.0.1"], port_list=["HTTP", "DNS"], protocol_list=["TCP"], + include_users=False, ) observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 1edb0442..69d9f106 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,7 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + include_users=False, ) assert host_obs.space["operating_status"] == spaces.Discrete(5) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index c534307f..48d29cfb 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -27,7 +27,7 @@ def test_router_observation(): port_list=["HTTP", "DNS"], protocol_list=["TCP"], ) - router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl) + router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py new file mode 100644 index 00000000..ca5e2543 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -0,0 +1,89 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +@pytest.fixture +def env_with_ssh() -> PrimaiteGymEnv: + """Build data manipulation environment with SSH port open on router.""" + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.agent.flatten_obs = False + router: Router = env.game.simulation.network.get_node_by_hostname("router_1") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + return env + + +def extract_login_numbers_from_obs(obs): + """Traverse the observation dictionary and return number of user sessions for all nodes.""" + login_nums = {} + for node_name, node_obs in obs["NODES"].items(): + login_nums[node_name] = node_obs.get("users") + return login_nums + + +class TestUserObservations: + """Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins.""" + + def test_no_sessions_at_episode_start(self, env_with_ssh): + """Test that all of the login observations start at 0 before any logins occur.""" + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_single_login(self, env_with_ssh: PrimaiteGymEnv): + """Test that performing a remote login increases the remote_sessions observation by 1.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + assert db_srv_logins_obs["local_login"] == 0 + assert db_srv_logins_obs["remote_sessions"] == 1 + for o in logins_obs.values(): # the remaining obs after popping HOST2 + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_logout(self, env_with_ssh: PrimaiteGymEnv): + """Test that remote_sessions observation correctly decreases upon logout.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv): + """Log in from 5 remote places and check that only a max of 3 is shown in the observation.""" + MAX_OBSERVABLE_SESSIONS = 3 + # Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_session_manager.remote_session_timeout_steps = 20 + db_srv.user_session_manager.max_remote_sessions = 5 + node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller") + + for i, node_name in enumerate(node_names): + node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name) + node.terminal._send_remote_login("admin", "admin", "192.168.1.14") + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1) + assert len(db_srv.user_session_manager.remote_sessions) == i + 1 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 83b04832..58783d70 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -72,25 +72,26 @@ def test_uc2_rewards(game_and_agent): request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response - ), + ahi = AgentHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == 1.0 + assert ahi.reward_info == {"connection_attempt_status": "success"} router.acl.remove_rule(position=2) response = game.simulation.apply_request(request) state = game.get_sim_state() + ahi = AgentHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response + ) reward_value = comp.calculate( state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response - ), + last_action_response=ahi, ) assert reward_value == -1.0 + assert ahi.reward_info == {"connection_attempt_status": "failure"} def test_shared_reward():