Merged PR 484: #2769: initial commit of user account actions

## Summary
Adding the Action Space for remote login and remote log out + change password

Updated ray to 2.32.0

## Test process
https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/484?_a=files&path=/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

#2769: initial commit of user account actions

Related work items: #2769
This commit is contained in:
Czar Echavez
2024-08-19 14:31:44 +00:00
committed by Marek Wolan
15 changed files with 1412 additions and 61 deletions

View File

@@ -10,11 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Random Number Generator Seeding by specifying a random number seed in the config file.
- Implemented Terminal service class, providing a generic terminal simulation.
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
- Added actions to establish SSH connections, send commands remotely and terminate SSH connections.
- Added actions to change users' passwords.
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
main port they're assigned.
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
- Node observations can now be configured to show the number of active local and remote logins.
### Fixed
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)

View File

@@ -1071,6 +1071,83 @@ class NodeNetworkServiceReconAction(AbstractAction):
]
class NodeAccountsChangePasswordAction(AbstractAction):
"""Action which changes the password for a user."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"UserManager",
"change_password",
username,
current_password,
new_password,
]
class NodeSessionsRemoteLoginAction(AbstractAction):
"""Action which performs a remote session login."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"Terminal",
"ssh_to_remote",
username,
password,
remote_ip,
]
class NodeSessionsRemoteLogoutAction(AbstractAction):
"""Action which performs a remote session logout."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: str, remote_ip: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip]
class NodeSendRemoteCommandAction(AbstractAction):
"""Action which sends a terminal command to a remote node via SSH."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"Terminal",
"send_remote_command",
remote_ip,
{"command": command},
]
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -1122,6 +1199,10 @@ class ActionManager:
"CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction,
"CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction,
"CONFIGURE_DOSBOT": ConfigureDoSBotAction,
"NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction,
"SSH_TO_REMOTE": NodeSessionsRemoteLoginAction,
"SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction,
"NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -10,6 +10,7 @@ from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
@@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list: List[int],
protocol_list: List[str],
num_rules: int,
include_users: bool,
) -> None:
"""
Initialise a firewall observation instance.
@@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
@@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
firewall_state = access_from_nested_dict(state, self.where)
if firewall_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
@@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
},
},
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
include_users=config.include_users,
)

View File

@@ -52,6 +52,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -69,6 +71,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
monitored_traffic: Optional[Dict],
include_num_access: bool,
file_system_requires_scan: bool,
include_users: bool,
) -> None:
"""
Initialise a host observation instance.
@@ -103,10 +106,15 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
If False, the true state is always shown.
:type file_system_requires_scan: bool
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_num_access = include_num_access
self.include_users = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
@@ -165,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -192,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -216,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod
@@ -273,4 +293,5 @@ class HostObservation(AbstractObservation, identifier="HOST"):
monitored_traffic=config.monitored_traffic,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
include_users=config.include_users,
)

View File

@@ -46,6 +46,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""Flag to include the number of accesses."""
file_system_requires_scan: bool = True
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
include_users: Optional[bool] = True
"""If True, report user session information."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
@@ -191,6 +193,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
host_config.include_num_access = config.include_num_access
if host_config.file_system_requires_scan is None:
host_config.file_system_requires_scan = config.file_system_requires_scan
if host_config.include_users is None:
host_config.include_users = config.include_users
for router_config in config.routers:
if router_config.num_ports is None:
@@ -205,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
if router_config.include_users is None:
router_config.include_users = config.include_users
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
@@ -217,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
if firewall_config.include_users is None:
firewall_config.include_users = config.include_users
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]

View File

@@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
ports: List[PortObservation],
num_ports: int,
acl: ACLObservation,
include_users: bool,
) -> None:
"""
Initialise a router observation instance.
@@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
:type num_ports: int
:param acl: ACL observation representing the access control list of the router.
:type acl: ACLObservation
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports: int = num_ports
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
@@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users)

View File

@@ -990,6 +990,7 @@ class UserManager(Service):
if user and user.password == current_password:
user.password = new_password
self.sys_log.info(f"{self.name}: Password changed for {username}")
self._user_session_manager._logout_user(user=user)
return True
self.sys_log.info(f"{self.name}: Password change failed for {username}")
return False
@@ -1027,6 +1028,10 @@ class UserManager(Service):
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
return False
@property
def _user_session_manager(self) -> "UserSessionManager":
return self.software_manager.software["UserSessionManager"] # noqa
class UserSession(SimComponent):
"""
@@ -1260,7 +1265,8 @@ class UserSessionManager(Service):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["active_remote_logins"] = len(self.remote_sessions)
state["current_local_user"] = None if not self.local_session else self.local_session.user.username
state["active_remote_sessions"] = list(self.remote_sessions.keys())
return state
@property
@@ -1435,6 +1441,19 @@ class UserSessionManager(Service):
"""
return self._logout(local=False, remote_session_id=remote_session_id)
def _logout_user(self, user: Union[str, User]) -> bool:
"""End a user session by username or user object."""
if isinstance(user, str):
user = self._user_manager.users[user] # grab user object from username
for sess_id, session in self.remote_sessions.items():
if session.user is user:
self._logout(local=False, remote_session_id=sess_id)
return True
if self.local_user_logged_in and self.local_session.user is user:
self.local_logout()
return True
return False
@property
def local_user_logged_in(self) -> bool:
"""

View File

@@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor
# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication
class TerminalClientConnection(BaseModel):
"""
TerminalClientConnection Class.
@@ -92,7 +94,7 @@ class LocalTerminalConnection(TerminalClientConnection):
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return None
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
return self.parent_terminal.execute(command)
class RemoteTerminalConnection(TerminalClientConnection):
@@ -162,22 +164,6 @@ class Terminal(Service):
def _init_request_manager(self) -> RequestManager:
"""Initialise Request manager."""
rm = super()._init_request_manager()
rm.add_request(
"send",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())),
)
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._process_local_login(username=request[0], password=request[1])
if login:
return RequestResponse(
status="success",
data={
"ip_address": login.ip_address,
},
)
else:
return RequestResponse(status="failure", data={"reason": "Invalid login credentials"})
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
@@ -191,10 +177,34 @@ class Terminal(Service):
else:
return RequestResponse(status="failure", data={})
rm.add_request(
"ssh_to_remote",
request_type=RequestType(func=_remote_login),
)
def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from remote connection."""
ip_address = IPv4Address(request[0])
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = self._disconnect(remote_connection.connection_uuid)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={"reason": "No remote connection held."},
)
rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff))
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
command: str = request[0]
ip_address: IPv4Address = IPv4Address(request[1])
ip_address: IPv4Address = IPv4Address(request[0])
command: str = request[1]["command"]
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
@@ -209,30 +219,11 @@ class Terminal(Service):
data={},
)
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from connection."""
connection_uuid = request[0]
self.parent.user_session_manager.local_logout(connection_uuid)
self._disconnect(connection_uuid)
return RequestResponse(status="success", data={})
rm.add_request(
"Login",
request_type=RequestType(func=_login),
)
rm.add_request(
"Remote Login",
request_type=RequestType(func=_remote_login),
)
rm.add_request(
"Execute",
"send_remote_command",
request_type=RequestType(func=remote_execute_request),
)
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
return rm
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
@@ -280,13 +271,9 @@ class Terminal(Service):
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
if ip_address:
# Assuming that if IP is passed we are connecting to remote
return self._send_remote_login(
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
)
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
else:
return self._process_local_login(username=username, password=password)
@@ -313,6 +300,9 @@ class Terminal(Service):
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id):
self._disconnect(connection_id)
return False
return connection_id in self._connections
def _send_remote_login(
@@ -320,32 +310,24 @@ class Terminal(Service):
username: str,
password: str,
ip_address: IPv4Address,
connection_request_id: str,
connection_request_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Optional[RemoteTerminalConnection]:
"""Send a remote login attempt and connect to Node.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt.
:type: ip_address: IPv4Address
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: connection_request_id: Connection Request ID, if not provided, a new one is generated
:type: connection_request_id: Optional[str]
:param: is_reattempt: True if the request has been reattempted. Default False.
:type: is_reattempt: Optional[bool]
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
"""
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
connection_request_id = connection_request_id or str(uuid4())
if is_reattempt:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
@@ -360,6 +342,9 @@ class Terminal(Service):
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
return None
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)

View File

@@ -0,0 +1,942 @@
io_settings:
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
sys_log_level: WARNING
game:
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
monitored_traffic:
icmp:
- NONE
tcp:
- DNS
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 1
service_id: 0
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 1
service_id: 0
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 1
service_id: 0
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 1
service_id: 0
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 1
service_id: 0
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 1
service_id: 0
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 1
service_id: 0
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 1
service_id: 0
9: # check database.db file
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
node_id: 0
19:
action: "NODE_SHUTDOWN"
options:
node_id: 0
20:
action: NODE_STARTUP
options:
node_id: 0
21:
action: NODE_RESET
options:
node_id: 0
22:
action: "NODE_OS_SCAN"
options:
node_id: 1
23:
action: "NODE_SHUTDOWN"
options:
node_id: 1
24:
action: NODE_STARTUP
options:
node_id: 1
25:
action: NODE_RESET
options:
node_id: 1
26: # old action num: 18
action: "NODE_OS_SCAN"
options:
node_id: 2
27:
action: "NODE_SHUTDOWN"
options:
node_id: 2
28:
action: NODE_STARTUP
options:
node_id: 2
29:
action: NODE_RESET
options:
node_id: 2
30:
action: "NODE_OS_SCAN"
options:
node_id: 3
31:
action: "NODE_SHUTDOWN"
options:
node_id: 3
32:
action: NODE_STARTUP
options:
node_id: 3
33:
action: NODE_RESET
options:
node_id: 3
34:
action: "NODE_OS_SCAN"
options:
node_id: 4
35:
action: "NODE_SHUTDOWN"
options:
node_id: 4
36:
action: NODE_STARTUP
options:
node_id: 4
37:
action: NODE_RESET
options:
node_id: 4
38:
action: "NODE_OS_SCAN"
options:
node_id: 5
39: # old action num: 19 # shutdown client 1
action: "NODE_SHUTDOWN"
options:
node_id: 5
40: # old action num: 20
action: NODE_STARTUP
options:
node_id: 5
41: # old action num: 21
action: NODE_RESET
options:
node_id: 5
42:
action: "NODE_OS_SCAN"
options:
node_id: 6
43:
action: "NODE_SHUTDOWN"
options:
node_id: 6
44:
action: NODE_STARTUP
options:
node_id: 6
45:
action: NODE_RESET
options:
node_id: 6
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
options:
target_router: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 0
53: # old action num: 29
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 1
54: # old action num: 30
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 2
55: # old action num: 31
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 3
56: # old action num: 32
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 4
57: # old action num: 33
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 5
58: # old action num: 34
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 6
59: # old action num: 35
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 7
60: # old action num: 36
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 8
61: # old action num: 37
action: "ROUTER_ACL_REMOVERULE"
options:
target_router: router_1
position: 9
62: # old action num: 38
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
63: # old action num: 39
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
64: # old action num: 40
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
65: # old action num: 41
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
66: # old action num: 42
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
67: # old action num: 43
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
68: # old action num: 44
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
69: # old action num: 45
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
70: # old action num: 46
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
71: # old action num: 47
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
72: # old action num: 48
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
73: # old action num: 49
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
74: # old action num: 50
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
75: # old action num: 51
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
76: # old action num: 52
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
77: # old action num: 53
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
agent_settings:
flatten_obs: true
action_masking: true
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: router_1
type: router
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- hostname: switch_1
type: switch
num_ports: 8
- hostname: switch_2
type: switch
num_ports: 8
- hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: WebServer
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: FTPClient
- hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: FTPServer
- hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
network_interfaces:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- type: DNSClient
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: web_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_hostname: database_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 2

View File

@@ -458,6 +458,10 @@ def game_and_agent():
{"type": "HOST_NIC_DISABLE"},
{"type": "NETWORK_PORT_ENABLE"},
{"type": "NETWORK_PORT_DISABLE"},
{"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"},
{"type": "SSH_TO_REMOTE"},
{"type": "SESSIONS_REMOTE_LOGOFF"},
{"type": "NODE_SEND_REMOTE_COMMAND"},
]
action_space = ActionManager(

View File

@@ -0,0 +1,166 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4)
return (game, agent)
def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"SSH_TO_REMOTE",
{
"node_id": 0,
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
connection_established = False
for conn_str, conn_obj in client_1.terminal.connections.items():
conn_obj: RemoteTerminalConnection
if conn_obj.ip_address == server_1.network_interface[1].ip_address:
connection_established = True
if not connection_established:
pytest.fail("Remote SSH connection could not be established")
def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"SSH_TO_REMOTE",
{
"node_id": 0,
"username": "user123",
"password": "wrong_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "failure"
connection_established = False
for conn_str, conn_obj in client_1.terminal.connections.items():
conn_obj: RemoteTerminalConnection
if conn_obj.ip_address == server_1.network_interface[1].ip_address:
connection_established = True
if connection_established:
pytest.fail("Remote SSH connection was established despite wrong password")
def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_um: UserManager = server_1.software_manager.software["UserManager"]
server_1_um.add_user("user123", "password", is_admin=True)
action = (
"NODE_ACCOUNTS_CHANGE_PASSWORD",
{
"node_id": 1, # server_1
"username": "user123",
"current_password": "password",
"new_password": "different_password",
},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
assert server_1_um.users["user123"].password == "different_password"
def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
# create a new user account on server_1 that will be logged into remotely
server_1_usm: UserManager = server_1.software_manager.software["UserManager"]
server_1_usm.add_user("user123", "password", is_admin=True)
# Log in remotely
action = (
"SSH_TO_REMOTE",
{
"node_id": 0,
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
game.step()
# Change password
action = (
"NODE_ACCOUNTS_CHANGE_PASSWORD",
{
"node_id": 1, # server_1
"username": "user123",
"current_password": "password",
"new_password": "different_password",
},
)
agent.store_action(action)
game.step()
# Assert that the user cannot execute an action
action = (
"NODE_SEND_REMOTE_COMMAND",
{
"node_id": 0,
"remote_ip": str(server_1.network_interface[1].ip_address),
"command": ["file_system", "create", "file", "folder123", "doggo.pdf", False],
},
)
agent.store_action(action)
game.step()
assert server_1.file_system.get_folder("folder123") is None
assert server_1.file_system.get_file("folder123", "doggo.pdf") is None

View File

@@ -33,6 +33,7 @@ def test_firewall_observation():
wildcard_list=["0.0.0.255", "0.0.0.1"],
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
include_users=False,
)
observation = firewall_observation.observe(firewall.describe_state())

View File

@@ -39,6 +39,7 @@ def test_host_observation(simulation):
folders=[],
network_interfaces=[],
file_system_requires_scan=True,
include_users=False,
)
assert host_obs.space["operating_status"] == spaces.Discrete(5)

View File

@@ -27,7 +27,7 @@ def test_router_observation():
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
)
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl)
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False)
# Observe the state using the RouterObservation instance
observed_output = router_observation.observe(router.describe_state())

View File

@@ -0,0 +1,89 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests import TEST_ASSETS_ROOT
DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
@pytest.fixture
def env_with_ssh() -> PrimaiteGymEnv:
"""Build data manipulation environment with SSH port open on router."""
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
env.agent.flatten_obs = False
router: Router = env.game.simulation.network.get_node_by_hostname("router_1")
router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3)
return env
def extract_login_numbers_from_obs(obs):
"""Traverse the observation dictionary and return number of user sessions for all nodes."""
login_nums = {}
for node_name, node_obs in obs["NODES"].items():
login_nums[node_name] = node_obs.get("users")
return login_nums
class TestUserObservations:
"""Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins."""
def test_no_sessions_at_episode_start(self, env_with_ssh):
"""Test that all of the login observations start at 0 before any logins occur."""
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
for o in logins_obs.values():
assert o["local_login"] == 0
assert o["remote_sessions"] == 0
def test_single_login(self, env_with_ssh: PrimaiteGymEnv):
"""Test that performing a remote login increases the remote_sessions observation by 1."""
client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1")
client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
assert db_srv_logins_obs["local_login"] == 0
assert db_srv_logins_obs["remote_sessions"] == 1
for o in logins_obs.values(): # the remaining obs after popping HOST2
assert o["local_login"] == 0
assert o["remote_sessions"] == 0
def test_logout(self, env_with_ssh: PrimaiteGymEnv):
"""Test that remote_sessions observation correctly decreases upon logout."""
client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1")
client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh
db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server")
db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
for o in logins_obs.values():
assert o["local_login"] == 0
assert o["remote_sessions"] == 0
def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv):
"""Log in from 5 remote places and check that only a max of 3 is shown in the observation."""
MAX_OBSERVABLE_SESSIONS = 3
# Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server")
db_srv.user_session_manager.remote_session_timeout_steps = 20
db_srv.user_session_manager.max_remote_sessions = 5
node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller")
for i, node_name in enumerate(node_names):
node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name)
node.terminal._send_remote_login("admin", "admin", "192.168.1.14")
obs, *_ = env_with_ssh.step(0)
logins_obs = extract_login_numbers_from_obs(obs)
db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server
assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1)
assert len(db_srv.user_session_manager.remote_sessions) == i + 1